jaxion 0.0.5__tar.gz → 0.0.7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {jaxion-0.0.5 → jaxion-0.0.7}/PKG-INFO +3 -3
- {jaxion-0.0.5 → jaxion-0.0.7}/README.md +2 -2
- jaxion-0.0.5/jaxion/params_default.json → jaxion-0.0.7/jaxion/defaults.json +12 -0
- {jaxion-0.0.5 → jaxion-0.0.7}/jaxion/gravity.py +2 -2
- {jaxion-0.0.5 → jaxion-0.0.7}/jaxion/particles.py +104 -1
- {jaxion-0.0.5 → jaxion-0.0.7}/jaxion/simulation.py +140 -48
- {jaxion-0.0.5 → jaxion-0.0.7}/jaxion/utils.py +1 -1
- {jaxion-0.0.5 → jaxion-0.0.7}/jaxion/visualization.py +8 -0
- {jaxion-0.0.5 → jaxion-0.0.7}/jaxion.egg-info/PKG-INFO +3 -3
- {jaxion-0.0.5 → jaxion-0.0.7}/jaxion.egg-info/SOURCES.txt +1 -1
- {jaxion-0.0.5 → jaxion-0.0.7}/pyproject.toml +1 -1
- {jaxion-0.0.5 → jaxion-0.0.7}/tests/test_examples.py +6 -6
- {jaxion-0.0.5 → jaxion-0.0.7}/LICENSE +0 -0
- {jaxion-0.0.5 → jaxion-0.0.7}/jaxion/__init__.py +0 -0
- {jaxion-0.0.5 → jaxion-0.0.7}/jaxion/analysis.py +0 -0
- {jaxion-0.0.5 → jaxion-0.0.7}/jaxion/constants.py +0 -0
- {jaxion-0.0.5 → jaxion-0.0.7}/jaxion/cosmology.py +0 -0
- {jaxion-0.0.5 → jaxion-0.0.7}/jaxion/hydro.py +0 -0
- {jaxion-0.0.5 → jaxion-0.0.7}/jaxion/quantum.py +0 -0
- {jaxion-0.0.5 → jaxion-0.0.7}/jaxion.egg-info/dependency_links.txt +0 -0
- {jaxion-0.0.5 → jaxion-0.0.7}/jaxion.egg-info/requires.txt +0 -0
- {jaxion-0.0.5 → jaxion-0.0.7}/jaxion.egg-info/top_level.txt +0 -0
- {jaxion-0.0.5 → jaxion-0.0.7}/requirements.txt +0 -0
- {jaxion-0.0.5 → jaxion-0.0.7}/setup.cfg +0 -0
- {jaxion-0.0.5 → jaxion-0.0.7}/tests/test_analysis.py +0 -0
- {jaxion-0.0.5 → jaxion-0.0.7}/tests/test_cosmology.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: jaxion
|
|
3
|
-
Version: 0.0.5
|
|
3
|
+
Version: 0.0.7
|
|
4
4
|
Summary: A differentiable simulation library for fuzzy dark matter in JAX
|
|
5
5
|
Author-email: Philip Mocz <philip.mocz@gmail.com>
|
|
6
6
|
License-Expression: Apache-2.0
|
|
@@ -68,7 +68,7 @@ Jaxion is built for multi-GPU scalability and is fully differentiable. It is a h
|
|
|
68
68
|
Jaxion is the simpler companion project to differentiable astrophysics code [Adirondax](https://github.com/AdirondaxProject/adirondax)
|
|
69
69
|
|
|
70
70
|
|
|
71
|
-
##
|
|
71
|
+
## Install Jaxion
|
|
72
72
|
|
|
73
73
|
Install with:
|
|
74
74
|
|
|
@@ -87,7 +87,7 @@ See the docs for more info on how to [build from source](https://jaxion.readthed
|
|
|
87
87
|
|
|
88
88
|
## Examples
|
|
89
89
|
|
|
90
|
-
Check out the `examples/` directory for demonstrations of using Jaxion.
|
|
90
|
+
Check out the [`examples/`](https://github.com/JaxionProject/jaxion/tree/main/examples/) directory for demonstrations of using Jaxion.
|
|
91
91
|
|
|
92
92
|
<p align="center">
|
|
93
93
|
<a href="https://github.com/JaxionProject/jaxion/tree/main/examples/cosmological_box">
|
|
@@ -45,7 +45,7 @@ Jaxion is built for multi-GPU scalability and is fully differentiable. It is a h
|
|
|
45
45
|
Jaxion is the simpler companion project to differentiable astrophysics code [Adirondax](https://github.com/AdirondaxProject/adirondax)
|
|
46
46
|
|
|
47
47
|
|
|
48
|
-
##
|
|
48
|
+
## Install Jaxion
|
|
49
49
|
|
|
50
50
|
Install with:
|
|
51
51
|
|
|
@@ -64,7 +64,7 @@ See the docs for more info on how to [build from source](https://jaxion.readthed
|
|
|
64
64
|
|
|
65
65
|
## Examples
|
|
66
66
|
|
|
67
|
-
Check out the `examples/` directory for demonstrations of using Jaxion.
|
|
67
|
+
Check out the [`examples/`](https://github.com/JaxionProject/jaxion/tree/main/examples/) directory for demonstrations of using Jaxion.
|
|
68
68
|
|
|
69
69
|
<p align="center">
|
|
70
70
|
<a href="https://github.com/JaxionProject/jaxion/tree/main/examples/cosmological_box">
|
|
@@ -48,6 +48,10 @@
|
|
|
48
48
|
"default": 1.0,
|
|
49
49
|
"description": "simulation end time [kpc/(km/s)] or [redshift] (cosmology=true)."
|
|
50
50
|
},
|
|
51
|
+
"safety_factor": {
|
|
52
|
+
"default": 1.0,
|
|
53
|
+
"description": "safety factor for time stepping."
|
|
54
|
+
},
|
|
51
55
|
"adaptive": {
|
|
52
56
|
"default": false,
|
|
53
57
|
"description": "switch on for adaptive time stepping."
|
|
@@ -75,6 +79,10 @@
|
|
|
75
79
|
"m_22": {
|
|
76
80
|
"default": 1.0,
|
|
77
81
|
"description": "axion mass [10^{-22} eV]."
|
|
82
|
+
},
|
|
83
|
+
"f_15": {
|
|
84
|
+
"default": 0.0,
|
|
85
|
+
"description": "self-interaction strength [10^{15} GeV]."
|
|
78
86
|
}
|
|
79
87
|
},
|
|
80
88
|
"hydro": {
|
|
@@ -91,6 +99,10 @@
|
|
|
91
99
|
"particle_mass": {
|
|
92
100
|
"default": 1.0,
|
|
93
101
|
"description": "particle mass [M_sun]."
|
|
102
|
+
},
|
|
103
|
+
"accrete_gas": {
|
|
104
|
+
"default": false,
|
|
105
|
+
"description": "switch on to accrete gas."
|
|
94
106
|
}
|
|
95
107
|
},
|
|
96
108
|
"cosmology": {
|
|
@@ -5,6 +5,6 @@ import jaxdecomp as jd
|
|
|
5
5
|
|
|
6
6
|
|
|
7
7
|
def calculate_gravitational_potential(rho, k_sq, G, rho_bar):
|
|
8
|
-
|
|
9
|
-
V = jnp.real(jd.fft.pifft3d(
|
|
8
|
+
V_hat = -jd.fft.pfft3d(4.0 * jnp.pi * G * (rho - rho_bar)) / (k_sq + (k_sq == 0))
|
|
9
|
+
V = jnp.real(jd.fft.pifft3d(V_hat))
|
|
10
10
|
return V
|
|
@@ -18,7 +18,7 @@ def get_cic_indices_and_weights(pos, dx, resolution):
|
|
|
18
18
|
return i, ip1, weight_i, weight_ip1
|
|
19
19
|
|
|
20
20
|
|
|
21
|
-
def bin_particles(pos, dx, resolution,
|
|
21
|
+
def bin_particles(pos, m_particles, dx, resolution, multiple_masses):
|
|
22
22
|
"""Bin the particles into the grid using cloud-in-cell weights."""
|
|
23
23
|
nx = resolution
|
|
24
24
|
n_particle = pos.shape[0]
|
|
@@ -27,6 +27,10 @@ def bin_particles(pos, dx, resolution, m_particle):
|
|
|
27
27
|
|
|
28
28
|
def deposit_particle(s, rho):
|
|
29
29
|
"""Deposit the particle mass into the grid."""
|
|
30
|
+
if multiple_masses:
|
|
31
|
+
m_particle = m_particles[s]
|
|
32
|
+
else:
|
|
33
|
+
m_particle = m_particles
|
|
30
34
|
fac = m_particle / (dx * dx * dx)
|
|
31
35
|
rho = rho.at[i[s, 0], i[s, 1], i[s, 2]].add(
|
|
32
36
|
w_i[s, 0] * w_i[s, 1] * w_i[s, 2] * fac
|
|
@@ -112,3 +116,102 @@ def particles_drift(pos, vel, dt, box_size):
|
|
|
112
116
|
pos = jnp.mod(pos, jnp.array([box_size, box_size, box_size]))
|
|
113
117
|
|
|
114
118
|
return pos
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def particles_accrete_gas(mass, rho, pos, G, sound_speed, dx, dt):
|
|
122
|
+
"""Accrete gas onto particles (Bondi)."""
|
|
123
|
+
n_particle = pos.shape[0]
|
|
124
|
+
resolution = rho.shape[0]
|
|
125
|
+
i, ip1, w_i, w_ip1 = get_cic_indices_and_weights(pos, dx, resolution)
|
|
126
|
+
d_mass = jnp.zeros_like(mass)
|
|
127
|
+
d_rho = jnp.zeros_like(rho)
|
|
128
|
+
lam = jnp.exp(1.5) / 4.0 # ≈ 1.12
|
|
129
|
+
vol = dx**3
|
|
130
|
+
|
|
131
|
+
def accrete(s, deltas):
|
|
132
|
+
"""Deposit the particle mass into the grid."""
|
|
133
|
+
d_mass, d_rho = deltas
|
|
134
|
+
dM_fac = dt * 4.0 * jnp.pi * lam * (G * mass[s]) ** 2 / sound_speed**3
|
|
135
|
+
# dM = dM_fac * rho
|
|
136
|
+
|
|
137
|
+
dm = w_i[s, 0] * w_i[s, 1] * w_i[s, 2] * dM_fac * rho[i[s, 0], i[s, 1], i[s, 2]]
|
|
138
|
+
d_rho = d_rho.at[i[s, 0], i[s, 1], i[s, 2]].add(-dm / vol)
|
|
139
|
+
d_mass = d_mass.at[s].add(dm)
|
|
140
|
+
|
|
141
|
+
dm = (
|
|
142
|
+
w_ip1[s, 0]
|
|
143
|
+
* w_i[s, 1]
|
|
144
|
+
* w_i[s, 2]
|
|
145
|
+
* dM_fac
|
|
146
|
+
* rho[ip1[s, 0], i[s, 1], i[s, 2]]
|
|
147
|
+
)
|
|
148
|
+
d_rho = d_rho.at[ip1[s, 0], i[s, 1], i[s, 2]].add(-dm / vol)
|
|
149
|
+
d_mass = d_mass.at[s].add(dm)
|
|
150
|
+
|
|
151
|
+
dm = (
|
|
152
|
+
w_i[s, 0]
|
|
153
|
+
* w_ip1[s, 1]
|
|
154
|
+
* w_i[s, 2]
|
|
155
|
+
* dM_fac
|
|
156
|
+
* rho[i[s, 0], ip1[s, 1], i[s, 2]]
|
|
157
|
+
)
|
|
158
|
+
d_rho = d_rho.at[i[s, 0], ip1[s, 1], i[s, 2]].add(-dm / vol)
|
|
159
|
+
d_mass = d_mass.at[s].add(dm)
|
|
160
|
+
|
|
161
|
+
dm = (
|
|
162
|
+
w_i[s, 0]
|
|
163
|
+
* w_i[s, 1]
|
|
164
|
+
* w_ip1[s, 2]
|
|
165
|
+
* dM_fac
|
|
166
|
+
* rho[i[s, 0], i[s, 1], ip1[s, 2]]
|
|
167
|
+
)
|
|
168
|
+
d_rho = d_rho.at[i[s, 0], i[s, 1], ip1[s, 2]].add(-dm / vol)
|
|
169
|
+
d_mass = d_mass.at[s].add(dm)
|
|
170
|
+
|
|
171
|
+
dm = (
|
|
172
|
+
w_ip1[s, 0]
|
|
173
|
+
* w_ip1[s, 1]
|
|
174
|
+
* w_i[s, 2]
|
|
175
|
+
* dM_fac
|
|
176
|
+
* rho[ip1[s, 0], ip1[s, 1], i[s, 2]]
|
|
177
|
+
)
|
|
178
|
+
d_rho = d_rho.at[ip1[s, 0], ip1[s, 1], i[s, 2]].add(-dm / vol)
|
|
179
|
+
d_mass = d_mass.at[s].add(dm)
|
|
180
|
+
|
|
181
|
+
dm = (
|
|
182
|
+
w_ip1[s, 0]
|
|
183
|
+
* w_i[s, 1]
|
|
184
|
+
* w_ip1[s, 2]
|
|
185
|
+
* dM_fac
|
|
186
|
+
* rho[ip1[s, 0], i[s, 1], ip1[s, 2]]
|
|
187
|
+
)
|
|
188
|
+
d_rho = d_rho.at[ip1[s, 0], i[s, 1], ip1[s, 2]].add(-dm / vol)
|
|
189
|
+
d_mass = d_mass.at[s].add(dm)
|
|
190
|
+
|
|
191
|
+
dm = (
|
|
192
|
+
w_i[s, 0]
|
|
193
|
+
* w_ip1[s, 1]
|
|
194
|
+
* w_ip1[s, 2]
|
|
195
|
+
* dM_fac
|
|
196
|
+
* rho[i[s, 0], ip1[s, 1], ip1[s, 2]]
|
|
197
|
+
)
|
|
198
|
+
d_rho = d_rho.at[i[s, 0], ip1[s, 1], ip1[s, 2]].add(-dm / vol)
|
|
199
|
+
d_mass = d_mass.at[s].add(dm)
|
|
200
|
+
|
|
201
|
+
dm = (
|
|
202
|
+
w_ip1[s, 0]
|
|
203
|
+
* w_ip1[s, 1]
|
|
204
|
+
* w_ip1[s, 2]
|
|
205
|
+
* dM_fac
|
|
206
|
+
* rho[ip1[s, 0], ip1[s, 1], ip1[s, 2]]
|
|
207
|
+
)
|
|
208
|
+
d_rho = d_rho.at[ip1[s, 0], ip1[s, 1], ip1[s, 2]].add(-dm / vol)
|
|
209
|
+
d_mass = d_mass.at[s].add(dm)
|
|
210
|
+
|
|
211
|
+
return d_mass, d_rho
|
|
212
|
+
|
|
213
|
+
d_mass, d_rho = jax.lax.fori_loop(0, n_particle, accrete, (d_mass, d_rho))
|
|
214
|
+
mass = mass + d_mass
|
|
215
|
+
rho = rho + d_rho
|
|
216
|
+
|
|
217
|
+
return mass, rho
|
|
@@ -9,7 +9,12 @@ from .constants import constants
|
|
|
9
9
|
from .quantum import quantum_kick, quantum_drift, quantum_velocity
|
|
10
10
|
from .gravity import calculate_gravitational_potential
|
|
11
11
|
from .hydro import hydro_fluxes, hydro_accelerate
|
|
12
|
-
from .particles import
|
|
12
|
+
from .particles import (
|
|
13
|
+
particles_accelerate,
|
|
14
|
+
particles_drift,
|
|
15
|
+
particles_accrete_gas,
|
|
16
|
+
bin_particles,
|
|
17
|
+
)
|
|
13
18
|
from .cosmology import get_supercomoving_time_interval, get_next_scale_factor
|
|
14
19
|
from .utils import (
|
|
15
20
|
set_up_parameters,
|
|
@@ -62,9 +67,10 @@ class Simulation:
|
|
|
62
67
|
self.params["physics"]["hydro"]
|
|
63
68
|
or self.params["physics"]["particles"]
|
|
64
69
|
or self.params["physics"]["external_potential"]
|
|
70
|
+
or self.params["quantum"]["f_15"] != 0.0
|
|
65
71
|
):
|
|
66
72
|
raise NotImplementedError(
|
|
67
|
-
"Cosmological hydro/particles/external_potential
|
|
73
|
+
"Cosmological hydro/particles/external_potential/SI is not yet implemented."
|
|
68
74
|
)
|
|
69
75
|
|
|
70
76
|
if self.params["physics"]["hydro"] or self.params["physics"]["particles"]:
|
|
@@ -94,6 +100,12 @@ class Simulation:
|
|
|
94
100
|
xones, static_argnums=0, in_shardings=None, out_shardings=sharding
|
|
95
101
|
)
|
|
96
102
|
|
|
103
|
+
# custom functions
|
|
104
|
+
self.custom_kick = None
|
|
105
|
+
self.custom_drift = None
|
|
106
|
+
self.custom_density = None
|
|
107
|
+
self.custom_plot = None
|
|
108
|
+
|
|
97
109
|
# simulation state
|
|
98
110
|
self.state = {}
|
|
99
111
|
self.state["t"] = 0.0
|
|
@@ -119,6 +131,11 @@ class Simulation:
|
|
|
119
131
|
if self.params["physics"]["particles"]:
|
|
120
132
|
self.state["pos"] = jnp.zeros((self.num_particles, 3))
|
|
121
133
|
self.state["vel"] = jnp.zeros((self.num_particles, 3))
|
|
134
|
+
if self.params["particles"]["accrete_gas"]:
|
|
135
|
+
self.state["mass"] = (
|
|
136
|
+
jnp.zeros(self.num_particles)
|
|
137
|
+
+ self.params["particles"]["particle_mass"]
|
|
138
|
+
)
|
|
122
139
|
|
|
123
140
|
if load_from_checkpoint:
|
|
124
141
|
options = ocp.CheckpointManagerOptions()
|
|
@@ -173,6 +190,23 @@ class Simulation:
|
|
|
173
190
|
/ constants["speed_of_light"] ** 2
|
|
174
191
|
)
|
|
175
192
|
|
|
193
|
+
@property
|
|
194
|
+
def scattering_length(self):
|
|
195
|
+
"""
|
|
196
|
+
Return the axion self-interaction scattering length in the simulation (kpc)
|
|
197
|
+
"""
|
|
198
|
+
f_15 = self.params["quantum"]["f_15"]
|
|
199
|
+
if f_15 == 0.0:
|
|
200
|
+
return 0.0
|
|
201
|
+
else:
|
|
202
|
+
f = f_15 * 1.0e24 * constants["electron_volt"]
|
|
203
|
+
sign = 1.0 if f > 0 else -1.0
|
|
204
|
+
hbar = constants["reduced_planck_constant"]
|
|
205
|
+
c = constants["speed_of_light"]
|
|
206
|
+
m = self.axion_mass
|
|
207
|
+
a_s = sign * (hbar * c**3 * m) / (32.0 * jnp.pi * (f**2))
|
|
208
|
+
return a_s
|
|
209
|
+
|
|
176
210
|
@property
|
|
177
211
|
def m_per_hbar(self):
|
|
178
212
|
"""
|
|
@@ -224,6 +258,13 @@ class Simulation:
|
|
|
224
258
|
"""
|
|
225
259
|
return quantum_velocity(self.state["psi"], self.box_size, self.m_per_hbar)
|
|
226
260
|
|
|
261
|
+
@property
|
|
262
|
+
def rho_bar(self):
|
|
263
|
+
"""
|
|
264
|
+
Return the mean density of the simulation
|
|
265
|
+
"""
|
|
266
|
+
return self._calc_rho_bar(self.state)
|
|
267
|
+
|
|
227
268
|
def _calc_rho_bar(self, state):
|
|
228
269
|
rho_bar = 0.0
|
|
229
270
|
if self.params["physics"]["quantum"]:
|
|
@@ -231,15 +272,19 @@ class Simulation:
|
|
|
231
272
|
if self.params["physics"]["hydro"]:
|
|
232
273
|
rho_bar += jnp.mean(state["rho"])
|
|
233
274
|
if self.params["physics"]["particles"]:
|
|
234
|
-
m_particle = self.params["particles"]["particle_mass"]
|
|
235
|
-
n_particles = self.num_particles
|
|
236
275
|
box_size = self.box_size
|
|
237
|
-
|
|
276
|
+
if self.params["particles"]["accrete_gas"]:
|
|
277
|
+
rho_bar += jnp.sum(state["mass"]) / box_size**3
|
|
278
|
+
else:
|
|
279
|
+
m_particle = self.params["particles"]["particle_mass"]
|
|
280
|
+
n_particles = self.num_particles
|
|
281
|
+
rho_bar += m_particle * n_particles / box_size**3
|
|
282
|
+
if self.custom_density is not None:
|
|
283
|
+
rho_bar += jnp.mean(self.custom_density(state))
|
|
238
284
|
return rho_bar
|
|
239
285
|
|
|
240
286
|
def _calc_grav_potential(self, state, k_sq):
|
|
241
287
|
G = constants["gravitational_constant"]
|
|
242
|
-
m_particle = self.params["particles"]["particle_mass"]
|
|
243
288
|
rho_bar = self._calc_rho_bar(state)
|
|
244
289
|
rho_tot = 0.0
|
|
245
290
|
if self.params["physics"]["quantum"]:
|
|
@@ -247,7 +292,19 @@ class Simulation:
|
|
|
247
292
|
if self.params["physics"]["hydro"]:
|
|
248
293
|
rho_tot += state["rho"]
|
|
249
294
|
if self.params["physics"]["particles"]:
|
|
250
|
-
|
|
295
|
+
multiple_masses = self.params["particles"]["accrete_gas"]
|
|
296
|
+
if multiple_masses:
|
|
297
|
+
m_particles = state["mass"]
|
|
298
|
+
rho_tot += bin_particles(
|
|
299
|
+
state["pos"], m_particles, self.dx, self.resolution, multiple_masses
|
|
300
|
+
)
|
|
301
|
+
else:
|
|
302
|
+
m_particles = self.params["particles"]["particle_mass"]
|
|
303
|
+
rho_tot += bin_particles(
|
|
304
|
+
state["pos"], m_particles, self.dx, self.resolution, multiple_masses
|
|
305
|
+
)
|
|
306
|
+
if self.custom_density is not None:
|
|
307
|
+
rho_tot += self.custom_density(state)
|
|
251
308
|
if self.params["physics"]["cosmology"]:
|
|
252
309
|
scale_factor = 1.0 / (1.0 + state["redshift"])
|
|
253
310
|
rho_bar *= scale_factor
|
|
@@ -280,18 +337,32 @@ class Simulation:
|
|
|
280
337
|
|
|
281
338
|
# Simulation parameters
|
|
282
339
|
dx = self.dx
|
|
283
|
-
m_per_hbar = self.m_per_hbar
|
|
284
340
|
box_size = self.box_size
|
|
341
|
+
num_cells = self.resolution**3
|
|
342
|
+
m_per_hbar = self.m_per_hbar
|
|
285
343
|
|
|
286
|
-
|
|
287
|
-
dt_kin =
|
|
344
|
+
safety = self.params["time"]["safety_factor"]
|
|
345
|
+
dt_kin = safety * (m_per_hbar / 6.0) * (dx * dx)
|
|
288
346
|
t_start = self.params["time"]["start"]
|
|
289
347
|
t_end = self.params["time"]["end"]
|
|
290
348
|
t_span = t_end - t_start
|
|
291
349
|
state["t"] = t_start
|
|
292
350
|
|
|
351
|
+
use_quantum = self.params["physics"]["quantum"]
|
|
352
|
+
use_gravity = self.params["physics"]["gravity"]
|
|
353
|
+
use_hydro = self.params["physics"]["hydro"]
|
|
354
|
+
use_particles = self.params["physics"]["particles"]
|
|
355
|
+
use_cosmology = self.params["physics"]["cosmology"]
|
|
356
|
+
use_external_potential = self.params["physics"]["external_potential"]
|
|
357
|
+
accrete_gas = self.params["particles"]["accrete_gas"]
|
|
358
|
+
save = self.params["output"]["save"]
|
|
359
|
+
use_custom = self.custom_kick is not None or self.custom_drift is not None
|
|
360
|
+
if use_custom:
|
|
361
|
+
custom_kick = self.custom_kick
|
|
362
|
+
custom_drift = self.custom_drift
|
|
363
|
+
|
|
293
364
|
# cosmology
|
|
294
|
-
if
|
|
365
|
+
if use_cosmology:
|
|
295
366
|
z_start = self.params["time"]["start"]
|
|
296
367
|
z_end = self.params["time"]["end"]
|
|
297
368
|
omega_matter = self.params["cosmology"]["omega_matter"]
|
|
@@ -303,6 +374,17 @@ class Simulation:
|
|
|
303
374
|
state["t"] = 0.0
|
|
304
375
|
state["redshift"] = z_start
|
|
305
376
|
|
|
377
|
+
# self-interaction
|
|
378
|
+
a_s = self.scattering_length
|
|
379
|
+
c = constants["speed_of_light"]
|
|
380
|
+
m = self.axion_mass
|
|
381
|
+
si_coeff = None
|
|
382
|
+
si_coeff2 = None
|
|
383
|
+
do_self_interaction = a_s != 0.0
|
|
384
|
+
if do_self_interaction:
|
|
385
|
+
si_coeff = (4.0 * jnp.pi) * (a_s / m) / m_per_hbar**2
|
|
386
|
+
si_coeff2 = (32.0 * jnp.pi**2 / 3.0) * (a_s / (m * c)) ** 2 / m_per_hbar**5
|
|
387
|
+
|
|
306
388
|
# hydro
|
|
307
389
|
c_sound = self.params["hydro"]["sound_speed"]
|
|
308
390
|
|
|
@@ -317,16 +399,14 @@ class Simulation:
|
|
|
317
399
|
k_sq = None
|
|
318
400
|
|
|
319
401
|
# Fourier space variables
|
|
320
|
-
if
|
|
402
|
+
if use_gravity or use_quantum:
|
|
321
403
|
kx, ky, kz = self.kgrid
|
|
322
404
|
k_sq = kx**2 + ky**2 + kz**2
|
|
323
405
|
|
|
324
406
|
# Checkpointer
|
|
325
|
-
if
|
|
407
|
+
if save:
|
|
326
408
|
options = ocp.CheckpointManagerOptions()
|
|
327
|
-
checkpoint_dir =
|
|
328
|
-
os.getcwd(), self.params["output"]["path"]
|
|
329
|
-
)
|
|
409
|
+
checkpoint_dir = os.path.join(os.getcwd(), self.params["output"]["path"])
|
|
330
410
|
path = os.path.join(os.getcwd(), checkpoint_dir)
|
|
331
411
|
if jax.process_index() == 0:
|
|
332
412
|
path = ocp.test_utils.erase_and_create_empty(checkpoint_dir)
|
|
@@ -336,80 +416,91 @@ class Simulation:
|
|
|
336
416
|
|
|
337
417
|
def _kick(state, kx, ky, kz, k_sq, dt):
|
|
338
418
|
# Kick (half-step)
|
|
339
|
-
if
|
|
340
|
-
self.params["physics"]["gravity"]
|
|
341
|
-
and self.params["physics"]["external_potential"]
|
|
342
|
-
):
|
|
419
|
+
if use_gravity and use_external_potential:
|
|
343
420
|
V = self._calc_grav_potential(state, k_sq) + state["V_ext"]
|
|
344
|
-
elif
|
|
421
|
+
elif use_gravity:
|
|
345
422
|
V = self._calc_grav_potential(state, k_sq)
|
|
346
|
-
elif
|
|
423
|
+
elif use_external_potential:
|
|
347
424
|
V = state["V_ext"]
|
|
348
425
|
|
|
349
|
-
if
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
426
|
+
if use_gravity or use_external_potential:
|
|
427
|
+
if use_quantum:
|
|
428
|
+
if do_self_interaction:
|
|
429
|
+
rho = jnp.abs(state["psi"]) ** 2
|
|
430
|
+
V_prime = V + si_coeff * rho + si_coeff2 * rho**2
|
|
431
|
+
state["psi"] = quantum_kick(
|
|
432
|
+
state["psi"], V_prime, m_per_hbar, dt
|
|
433
|
+
)
|
|
434
|
+
else:
|
|
435
|
+
state["psi"] = quantum_kick(state["psi"], V, m_per_hbar, dt)
|
|
436
|
+
if use_hydro:
|
|
356
437
|
state["vx"], state["vy"], state["vz"] = hydro_accelerate(
|
|
357
438
|
state["vx"], state["vy"], state["vz"], V, kx, ky, kz, dt
|
|
358
439
|
)
|
|
359
|
-
if
|
|
440
|
+
if use_particles:
|
|
360
441
|
state["vel"] = particles_accelerate(
|
|
361
442
|
state["vel"], state["pos"], V, kx, ky, kz, dx, dt
|
|
362
443
|
)
|
|
444
|
+
if use_custom:
|
|
445
|
+
state = custom_kick(state, V, dt)
|
|
446
|
+
|
|
447
|
+
return state
|
|
363
448
|
|
|
364
449
|
def _drift(state, k_sq, dt):
|
|
365
450
|
# Drift (full-step)
|
|
366
|
-
if
|
|
451
|
+
if use_quantum:
|
|
367
452
|
state["psi"] = quantum_drift(state["psi"], k_sq, m_per_hbar, dt)
|
|
368
|
-
if
|
|
453
|
+
if use_hydro:
|
|
369
454
|
state["rho"], state["vx"], state["vy"], state["vz"] = hydro_fluxes(
|
|
370
455
|
state["rho"], state["vx"], state["vy"], state["vz"], dt, dx, c_sound
|
|
371
456
|
)
|
|
372
|
-
if
|
|
457
|
+
if use_particles:
|
|
373
458
|
state["pos"] = particles_drift(state["pos"], state["vel"], dt, box_size)
|
|
459
|
+
if use_custom:
|
|
460
|
+
state = custom_drift(state, k_sq, dt)
|
|
461
|
+
if use_hydro and accrete_gas:
|
|
462
|
+
G = constants["gravitational_constant"]
|
|
463
|
+
state["mass"], state["rho"] = particles_accrete_gas(
|
|
464
|
+
state["mass"], state["rho"], state["pos"], G, c_sound, dx, dt
|
|
465
|
+
)
|
|
466
|
+
|
|
467
|
+
return state
|
|
374
468
|
|
|
375
|
-
@jax.jit
|
|
376
469
|
def _update(_, carry):
|
|
377
470
|
# Update the simulation state by one timestep
|
|
378
471
|
# according to a 2nd-order `kick-drift-kick` scheme
|
|
379
472
|
state, kx, ky, kz, k_sq = carry
|
|
380
|
-
_kick(state, kx, ky, kz, k_sq, 0.5 * dt)
|
|
381
|
-
_drift(state, k_sq, dt)
|
|
473
|
+
state = _kick(state, kx, ky, kz, k_sq, 0.5 * dt)
|
|
474
|
+
state = _drift(state, k_sq, dt)
|
|
382
475
|
# update time & redshift
|
|
383
476
|
state["t"] += dt
|
|
384
|
-
if
|
|
477
|
+
if use_cosmology:
|
|
385
478
|
scale_factor = get_next_scale_factor(
|
|
386
|
-
state["redshift"],
|
|
387
|
-
dt,
|
|
388
|
-
self.params["cosmology"]["omega_matter"],
|
|
389
|
-
self.params["cosmology"]["omega_lambda"],
|
|
390
|
-
self.params["cosmology"]["little_h"],
|
|
479
|
+
state["redshift"], dt, omega_matter, omega_lambda, little_h
|
|
391
480
|
)
|
|
392
481
|
state["redshift"] = 1.0 / scale_factor - 1.0
|
|
393
|
-
_kick(state, kx, ky, kz, k_sq, 0.5 * dt)
|
|
482
|
+
state = _kick(state, kx, ky, kz, k_sq, 0.5 * dt)
|
|
394
483
|
|
|
395
484
|
return state, kx, ky, kz, k_sq
|
|
396
485
|
|
|
397
486
|
# save initial state
|
|
398
487
|
if jax.process_index() == 0:
|
|
399
488
|
print(f"Starting simulation (res={self.resolution}, nt={nt}) ...")
|
|
400
|
-
if
|
|
489
|
+
if save:
|
|
401
490
|
with open(os.path.join(checkpoint_dir, "params.json"), "w") as f:
|
|
402
491
|
json.dump(self.params, f, indent=2)
|
|
403
492
|
async_checkpoint_manager.save(0, args=ocp.args.StandardSave(state))
|
|
404
493
|
plot_sim(state, checkpoint_dir, 0, self.params)
|
|
494
|
+
if self.custom_plot is not None:
|
|
495
|
+
self.custom_plot(state, checkpoint_dir, 0, self.params)
|
|
405
496
|
async_checkpoint_manager.wait_until_finished()
|
|
406
497
|
|
|
407
498
|
# Simulation Main Loop
|
|
408
499
|
t_start_timer = time.time()
|
|
409
|
-
if
|
|
500
|
+
if save:
|
|
410
501
|
for i in range(1, num_checkpoints + 1):
|
|
411
502
|
carry = jax.lax.fori_loop(0, nt_sub, _update, init_val=carry)
|
|
412
|
-
state,
|
|
503
|
+
state, _, _, _, _ = carry
|
|
413
504
|
jax.block_until_ready(state)
|
|
414
505
|
# save state
|
|
415
506
|
async_checkpoint_manager.save(i, args=ocp.args.StandardSave(state))
|
|
@@ -417,17 +508,18 @@ class Simulation:
|
|
|
417
508
|
elapsed = time.time() - t_start_timer
|
|
418
509
|
est_total = elapsed / i * num_checkpoints
|
|
419
510
|
est_remaining = est_total - elapsed
|
|
420
|
-
num_cells = self.resolution**3
|
|
421
511
|
mcups = (num_cells * (i * nt_sub)) / (elapsed * 1.0e6)
|
|
422
512
|
if jax.process_index() == 0:
|
|
423
513
|
print(
|
|
424
514
|
f"{percent:.1f}%: mcups={mcups:.1f}, estimated time left (s): {est_remaining:.1f}"
|
|
425
515
|
)
|
|
426
516
|
plot_sim(state, checkpoint_dir, i, self.params)
|
|
517
|
+
if self.custom_plot is not None:
|
|
518
|
+
self.custom_plot(state, checkpoint_dir, i, self.params)
|
|
427
519
|
async_checkpoint_manager.wait_until_finished()
|
|
428
520
|
else:
|
|
429
521
|
carry = jax.lax.fori_loop(0, nt, _update, init_val=carry)
|
|
430
|
-
state,
|
|
522
|
+
state, _, _, _, _ = carry
|
|
431
523
|
jax.block_until_ready(state)
|
|
432
524
|
if jax.process_index() == 0:
|
|
433
525
|
print("Simulation Run Time (s): ", time.time() - t_start_timer)
|
|
@@ -31,7 +31,7 @@ def print_distributed_info():
|
|
|
31
31
|
|
|
32
32
|
def set_up_parameters(user_overwrites):
|
|
33
33
|
# first load the default params
|
|
34
|
-
params_path = importlib.resources.files("jaxion") / "
|
|
34
|
+
params_path = importlib.resources.files("jaxion") / "defaults.json"
|
|
35
35
|
with params_path.open("r", encoding="utf-8") as f:
|
|
36
36
|
params = json.load(f)
|
|
37
37
|
|
|
@@ -82,6 +82,14 @@ def plot_sim(state, checkpoint_dir, i, params):
|
|
|
82
82
|
vmax=vmax,
|
|
83
83
|
extent=[0, nx, 0, nx],
|
|
84
84
|
)
|
|
85
|
+
if params["physics"]["particles"]:
|
|
86
|
+
# draw particles
|
|
87
|
+
box_size = params["domain"]["box_size"]
|
|
88
|
+
sx = (state["pos"][:, 0] / box_size) * nx
|
|
89
|
+
sy = (state["pos"][:, 1] / box_size) * nx
|
|
90
|
+
plt.plot(
|
|
91
|
+
sx, sy, color="red", marker=".", linestyle="None", markersize=5
|
|
92
|
+
)
|
|
85
93
|
ax.set_aspect("equal")
|
|
86
94
|
ax.get_xaxis().set_visible(False)
|
|
87
95
|
ax.get_yaxis().set_visible(False)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: jaxion
|
|
3
|
-
Version: 0.0.5
|
|
3
|
+
Version: 0.0.7
|
|
4
4
|
Summary: A differentiable simulation library for fuzzy dark matter in JAX
|
|
5
5
|
Author-email: Philip Mocz <philip.mocz@gmail.com>
|
|
6
6
|
License-Expression: Apache-2.0
|
|
@@ -68,7 +68,7 @@ Jaxion is built for multi-GPU scalability and is fully differentiable. It is a h
|
|
|
68
68
|
Jaxion is the simpler companion project to differentiable astrophysics code [Adirondax](https://github.com/AdirondaxProject/adirondax)
|
|
69
69
|
|
|
70
70
|
|
|
71
|
-
##
|
|
71
|
+
## Install Jaxion
|
|
72
72
|
|
|
73
73
|
Install with:
|
|
74
74
|
|
|
@@ -87,7 +87,7 @@ See the docs for more info on how to [build from source](https://jaxion.readthed
|
|
|
87
87
|
|
|
88
88
|
## Examples
|
|
89
89
|
|
|
90
|
-
Check out the `examples/` directory for demonstrations of using Jaxion.
|
|
90
|
+
Check out the [`examples/`](https://github.com/JaxionProject/jaxion/tree/main/examples/) directory for demonstrations of using Jaxion.
|
|
91
91
|
|
|
92
92
|
<p align="center">
|
|
93
93
|
<a href="https://github.com/JaxionProject/jaxion/tree/main/examples/cosmological_box">
|
|
@@ -12,7 +12,7 @@ def test_tidal_stripping():
|
|
|
12
12
|
)
|
|
13
13
|
assert sim.resolution == 32
|
|
14
14
|
assert sim.state["t"] > 0.0
|
|
15
|
-
assert jnp.mean(jnp.abs(sim.state["psi"])) == pytest.approx(
|
|
15
|
+
assert jnp.mean(jnp.abs(sim.state["psi"])) == pytest.approx(162.028, rel=rel_tol)
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
def test_tidal_stripping_distributed_emulate():
|
|
@@ -22,7 +22,7 @@ def test_tidal_stripping_distributed_emulate():
|
|
|
22
22
|
)
|
|
23
23
|
assert sim.resolution == 32
|
|
24
24
|
assert sim.state["t"] > 0.0
|
|
25
|
-
assert jnp.mean(jnp.abs(sim.state["psi"])) == pytest.approx(
|
|
25
|
+
assert jnp.mean(jnp.abs(sim.state["psi"])) == pytest.approx(162.028, rel=rel_tol)
|
|
26
26
|
|
|
27
27
|
|
|
28
28
|
def test_heating_gas():
|
|
@@ -41,13 +41,13 @@ def test_heating_stars():
|
|
|
41
41
|
)
|
|
42
42
|
assert sim.resolution == 32
|
|
43
43
|
assert sim.state["t"] > 0.0
|
|
44
|
-
assert jnp.mean(jnp.abs(sim.state["psi"])) == pytest.approx(2574.
|
|
44
|
+
assert jnp.mean(jnp.abs(sim.state["psi"])) == pytest.approx(2574.395, rel=rel_tol)
|
|
45
45
|
assert jnp.mean(jnp.abs(sim.state["vel"][:, 0])) == pytest.approx(
|
|
46
|
-
16.
|
|
46
|
+
16.625353, rel=rel_tol
|
|
47
47
|
)
|
|
48
48
|
assert jnp.mean(jnp.abs(sim.state["vel"][:, 1])) == pytest.approx(
|
|
49
|
-
17.
|
|
49
|
+
17.345486, rel=rel_tol
|
|
50
50
|
)
|
|
51
51
|
assert jnp.mean(jnp.abs(sim.state["vel"][:, 2])) == pytest.approx(
|
|
52
|
-
18.
|
|
52
|
+
18.218296, rel=rel_tol
|
|
53
53
|
)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|