jaxion 0.0.5__tar.gz → 0.0.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (26)
  1. {jaxion-0.0.5 → jaxion-0.0.7}/PKG-INFO +3 -3
  2. {jaxion-0.0.5 → jaxion-0.0.7}/README.md +2 -2
  3. jaxion-0.0.5/jaxion/params_default.json → jaxion-0.0.7/jaxion/defaults.json +12 -0
  4. {jaxion-0.0.5 → jaxion-0.0.7}/jaxion/gravity.py +2 -2
  5. {jaxion-0.0.5 → jaxion-0.0.7}/jaxion/particles.py +104 -1
  6. {jaxion-0.0.5 → jaxion-0.0.7}/jaxion/simulation.py +140 -48
  7. {jaxion-0.0.5 → jaxion-0.0.7}/jaxion/utils.py +1 -1
  8. {jaxion-0.0.5 → jaxion-0.0.7}/jaxion/visualization.py +8 -0
  9. {jaxion-0.0.5 → jaxion-0.0.7}/jaxion.egg-info/PKG-INFO +3 -3
  10. {jaxion-0.0.5 → jaxion-0.0.7}/jaxion.egg-info/SOURCES.txt +1 -1
  11. {jaxion-0.0.5 → jaxion-0.0.7}/pyproject.toml +1 -1
  12. {jaxion-0.0.5 → jaxion-0.0.7}/tests/test_examples.py +6 -6
  13. {jaxion-0.0.5 → jaxion-0.0.7}/LICENSE +0 -0
  14. {jaxion-0.0.5 → jaxion-0.0.7}/jaxion/__init__.py +0 -0
  15. {jaxion-0.0.5 → jaxion-0.0.7}/jaxion/analysis.py +0 -0
  16. {jaxion-0.0.5 → jaxion-0.0.7}/jaxion/constants.py +0 -0
  17. {jaxion-0.0.5 → jaxion-0.0.7}/jaxion/cosmology.py +0 -0
  18. {jaxion-0.0.5 → jaxion-0.0.7}/jaxion/hydro.py +0 -0
  19. {jaxion-0.0.5 → jaxion-0.0.7}/jaxion/quantum.py +0 -0
  20. {jaxion-0.0.5 → jaxion-0.0.7}/jaxion.egg-info/dependency_links.txt +0 -0
  21. {jaxion-0.0.5 → jaxion-0.0.7}/jaxion.egg-info/requires.txt +0 -0
  22. {jaxion-0.0.5 → jaxion-0.0.7}/jaxion.egg-info/top_level.txt +0 -0
  23. {jaxion-0.0.5 → jaxion-0.0.7}/requirements.txt +0 -0
  24. {jaxion-0.0.5 → jaxion-0.0.7}/setup.cfg +0 -0
  25. {jaxion-0.0.5 → jaxion-0.0.7}/tests/test_analysis.py +0 -0
  26. {jaxion-0.0.5 → jaxion-0.0.7}/tests/test_cosmology.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: jaxion
3
- Version: 0.0.5
3
+ Version: 0.0.7
4
4
  Summary: A differentiable simulation library for fuzzy dark matter in JAX
5
5
  Author-email: Philip Mocz <philip.mocz@gmail.com>
6
6
  License-Expression: Apache-2.0
@@ -68,7 +68,7 @@ Jaxion is built for multi-GPU scalability and is fully differentiable. It is a h
68
68
  Jaxion is the simpler companion project to differentiable astrophysics code [Adirondax](https://github.com/AdirondaxProject/adirondax)
69
69
 
70
70
 
71
- ## Getting started
71
+ ## Install Jaxion
72
72
 
73
73
  Install with:
74
74
 
@@ -87,7 +87,7 @@ See the docs for more info on how to [build from source](https://jaxion.readthed
87
87
 
88
88
  ## Examples
89
89
 
90
- Check out the `examples/` directory for demonstrations of using Jaxion.
90
+ Check out the [`examples/`](https://github.com/JaxionProject/jaxion/tree/main/examples/) directory for demonstrations of using Jaxion.
91
91
 
92
92
  <p align="center">
93
93
  <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/cosmological_box">
@@ -45,7 +45,7 @@ Jaxion is built for multi-GPU scalability and is fully differentiable. It is a h
45
45
  Jaxion is the simpler companion project to differentiable astrophysics code [Adirondax](https://github.com/AdirondaxProject/adirondax)
46
46
 
47
47
 
48
- ## Getting started
48
+ ## Install Jaxion
49
49
 
50
50
  Install with:
51
51
 
@@ -64,7 +64,7 @@ See the docs for more info on how to [build from source](https://jaxion.readthed
64
64
 
65
65
  ## Examples
66
66
 
67
- Check out the `examples/` directory for demonstrations of using Jaxion.
67
+ Check out the [`examples/`](https://github.com/JaxionProject/jaxion/tree/main/examples/) directory for demonstrations of using Jaxion.
68
68
 
69
69
  <p align="center">
70
70
  <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/cosmological_box">
@@ -48,6 +48,10 @@
48
48
  "default": 1.0,
49
49
  "description": "simulation end time [kpc/(km/s)] or [redshift] (cosmology=true)."
50
50
  },
51
+ "safety_factor": {
52
+ "default": 1.0,
53
+ "description": "safety factor for time stepping."
54
+ },
51
55
  "adaptive": {
52
56
  "default": false,
53
57
  "description": "switch on for adaptive time stepping."
@@ -75,6 +79,10 @@
75
79
  "m_22": {
76
80
  "default": 1.0,
77
81
  "description": "axion mass [10^{-22} eV]."
82
+ },
83
+ "f_15": {
84
+ "default": 0.0,
85
+ "description": "self-interaction strength [10^{15} GeV]."
78
86
  }
79
87
  },
80
88
  "hydro": {
@@ -91,6 +99,10 @@
91
99
  "particle_mass": {
92
100
  "default": 1.0,
93
101
  "description": "particle mass [M_sun]."
102
+ },
103
+ "accrete_gas": {
104
+ "default": false,
105
+ "description": "switch on to accrete gas."
94
106
  }
95
107
  },
96
108
  "cosmology": {
@@ -5,6 +5,6 @@ import jaxdecomp as jd
5
5
 
6
6
 
def calculate_gravitational_potential(rho, k_sq, G, rho_bar):
    """Solve the periodic Poisson equation for the gravitational potential.

    Spectral solve of  laplacian(V) = 4 * pi * G * (rho - rho_bar):

        V_hat = -FFT(4 * pi * G * (rho - rho_bar)) / k^2

    The k = 0 (mean) mode is not determined by the Poisson equation. It is set
    to zero explicitly, so the returned V has zero mean. This holds even under
    round-off, or when rho_bar is not exactly the mean of rho. A constant offset
    in V does not change its gradient.

    Args:
        rho: real-space density grid.
        k_sq: squared wavenumbers |k|^2. They are combined elementwise with the
            output of ``jd.fft.pfft3d``, so they must share that layout.
        G: gravitational constant.
        rho_bar: background (mean) density subtracted from rho.

    Returns:
        Real-space gravitational potential V on the same grid as rho.
    """
    k_is_zero = k_sq == 0
    # Add 1 only where k_sq == 0 so the division is safe (and NaN-free for
    # autodiff); that mode is overwritten just below.
    V_hat = -jd.fft.pfft3d(4.0 * jnp.pi * G * (rho - rho_bar)) / (k_sq + k_is_zero)
    # Fix the gauge: zero the undetermined mean mode.
    V_hat = jnp.where(k_is_zero, 0.0, V_hat)
    V = jnp.real(jd.fft.pifft3d(V_hat))
    return V
@@ -18,7 +18,7 @@ def get_cic_indices_and_weights(pos, dx, resolution):
18
18
  return i, ip1, weight_i, weight_ip1
19
19
 
20
20
 
21
- def bin_particles(pos, dx, resolution, m_particle):
21
+ def bin_particles(pos, m_particles, dx, resolution, multiple_masses):
22
22
  """Bin the particles into the grid using cloud-in-cell weights."""
23
23
  nx = resolution
24
24
  n_particle = pos.shape[0]
@@ -27,6 +27,10 @@ def bin_particles(pos, dx, resolution, m_particle):
27
27
 
28
28
  def deposit_particle(s, rho):
29
29
  """Deposit the particle mass into the grid."""
30
+ if multiple_masses:
31
+ m_particle = m_particles[s]
32
+ else:
33
+ m_particle = m_particles
30
34
  fac = m_particle / (dx * dx * dx)
31
35
  rho = rho.at[i[s, 0], i[s, 1], i[s, 2]].add(
32
36
  w_i[s, 0] * w_i[s, 1] * w_i[s, 2] * fac
@@ -112,3 +116,102 @@ def particles_drift(pos, vel, dt, box_size):
112
116
  pos = jnp.mod(pos, jnp.array([box_size, box_size, box_size]))
113
117
 
114
118
  return pos
119
+
120
+
def particles_accrete_gas(mass, rho, pos, G, sound_speed, dx, dt):
    """Accrete gas from the grid onto particles at the Bondi rate.

    Each particle draws mass from the 8 cloud-in-cell (CIC) cells around it,
    weighted by its CIC weights. The draw per cell is

        dm = dt * 4 * pi * lam * (G * M)**2 / sound_speed**3 * w_cic * rho_cell

    where lam = exp(3/2) / 4 ≈ 1.12 (the isothermal Bondi eigenvalue) and M is
    the particle mass. All particles use the gas density and particle masses
    from the start of the step. Mass removed from the grid (dm / cell volume in
    density) is added to the particle, so grid + particle mass is conserved.

    The total mass requested from any cell is capped at the mass that cell
    holds, so accretion can never drive the gas density negative. When the cap
    is active, every particle's draw from that cell is scaled by the same
    factor.

    Args:
        mass: per-particle masses, shape (n_particle,).
        rho: gas density grid; its resolution is taken from rho.shape[0].
        pos: particle positions, shape (n_particle, 3).
        G: gravitational constant.
        sound_speed: gas sound speed.
        dx: grid cell size.
        dt: timestep.

    Returns:
        (mass, rho): updated particle masses and gas density grid.
    """
    resolution = rho.shape[0]
    i, ip1, w_i, w_ip1 = get_cic_indices_and_weights(pos, dx, resolution)
    lam = jnp.exp(1.5) / 4.0  # ≈ 1.12
    vol = dx**3

    # Per-particle Bondi factor (depends only on the particle's own mass).
    dM_fac = dt * 4.0 * jnp.pi * lam * (G * mass) ** 2 / sound_speed**3

    # Enumerate the 8 CIC corner cells: along each axis take either the lower
    # (i) or upper (ip1) index together with the matching weight.
    idx = (i, ip1)
    wts = (w_i, w_ip1)
    corners = []  # list of (cell index tuple, mass requested from that cell)
    for ox in (0, 1):
        for oy in (0, 1):
            for oz in (0, 1):
                cell = (idx[ox][:, 0], idx[oy][:, 1], idx[oz][:, 2])
                weight = wts[ox][:, 0] * wts[oy][:, 1] * wts[oz][:, 2]
                corners.append((cell, weight * dM_fac * rho[cell]))

    # Total mass requested from each grid cell by all particles
    # (.at[].add accumulates repeated indices).
    requested = jnp.zeros_like(rho)
    for cell, dm in corners:
        requested = requested.at[cell].add(dm)

    # Limiter: never remove more than a cell holds. Safe denominators keep
    # the division finite (and differentiable) where nothing is requested.
    available = jnp.maximum(rho, 0.0) * vol
    safe_requested = jnp.where(requested > 0.0, requested, 1.0)
    scale = jnp.where(requested > available, available / safe_requested, 1.0)

    d_mass = jnp.zeros_like(mass)
    d_rho = jnp.zeros_like(rho)
    for cell, dm in corners:
        dm = dm * scale[cell]
        d_rho = d_rho.at[cell].add(-dm / vol)
        d_mass = d_mass + dm

    return mass + d_mass, rho + d_rho
@@ -9,7 +9,12 @@ from .constants import constants
9
9
  from .quantum import quantum_kick, quantum_drift, quantum_velocity
10
10
  from .gravity import calculate_gravitational_potential
11
11
  from .hydro import hydro_fluxes, hydro_accelerate
12
- from .particles import particles_accelerate, particles_drift, bin_particles
12
+ from .particles import (
13
+ particles_accelerate,
14
+ particles_drift,
15
+ particles_accrete_gas,
16
+ bin_particles,
17
+ )
13
18
  from .cosmology import get_supercomoving_time_interval, get_next_scale_factor
14
19
  from .utils import (
15
20
  set_up_parameters,
@@ -62,9 +67,10 @@ class Simulation:
62
67
  self.params["physics"]["hydro"]
63
68
  or self.params["physics"]["particles"]
64
69
  or self.params["physics"]["external_potential"]
70
+ or self.params["quantum"]["f_15"] != 0.0
65
71
  ):
66
72
  raise NotImplementedError(
67
- "Cosmological hydro/particles/external_potential physics is not yet implemented."
73
+ "Cosmological hydro/particles/external_potential/SI is not yet implemented."
68
74
  )
69
75
 
70
76
  if self.params["physics"]["hydro"] or self.params["physics"]["particles"]:
@@ -94,6 +100,12 @@ class Simulation:
94
100
  xones, static_argnums=0, in_shardings=None, out_shardings=sharding
95
101
  )
96
102
 
103
+ # custom functions
104
+ self.custom_kick = None
105
+ self.custom_drift = None
106
+ self.custom_density = None
107
+ self.custom_plot = None
108
+
97
109
  # simulation state
98
110
  self.state = {}
99
111
  self.state["t"] = 0.0
@@ -119,6 +131,11 @@ class Simulation:
119
131
  if self.params["physics"]["particles"]:
120
132
  self.state["pos"] = jnp.zeros((self.num_particles, 3))
121
133
  self.state["vel"] = jnp.zeros((self.num_particles, 3))
134
+ if self.params["particles"]["accrete_gas"]:
135
+ self.state["mass"] = (
136
+ jnp.zeros(self.num_particles)
137
+ + self.params["particles"]["particle_mass"]
138
+ )
122
139
 
123
140
  if load_from_checkpoint:
124
141
  options = ocp.CheckpointManagerOptions()
@@ -173,6 +190,23 @@ class Simulation:
173
190
  / constants["speed_of_light"] ** 2
174
191
  )
175
192
 
@property
def scattering_length(self):
    """
    Return the axion self-interaction scattering length a_s in the simulation (kpc).

    The decay constant is f = f_15 * 10^24 eV (f_15 is given in units of
    10^15 GeV), converted via constants["electron_volt"]. Then

        a_s = sign(f) * hbar * c^3 * m / (32 * pi * f^2)

    with m = self.axion_mass. When f_15 == 0 the result is exactly 0.0
    (no self-interaction).
    """
    f_15 = self.params["quantum"]["f_15"]
    if f_15 == 0.0:
        return 0.0

    decay_constant = f_15 * 1.0e24 * constants["electron_volt"]
    hbar = constants["reduced_planck_constant"]
    c = constants["speed_of_light"]
    magnitude = (hbar * c**3 * self.axion_mass) / (
        32.0 * jnp.pi * (decay_constant**2)
    )
    # a_s carries the sign of the decay constant
    return magnitude if decay_constant > 0 else -magnitude
209
+
176
210
  @property
177
211
  def m_per_hbar(self):
178
212
  """
@@ -224,6 +258,13 @@ class Simulation:
224
258
  """
225
259
  return quantum_velocity(self.state["psi"], self.box_size, self.m_per_hbar)
226
260
 
@property
def rho_bar(self):
    """
    Mean density of the simulation, evaluated on the current state.

    Delegates to ``_calc_rho_bar`` with ``self.state``.
    """
    current_state = self.state
    return self._calc_rho_bar(current_state)
267
+
227
268
  def _calc_rho_bar(self, state):
228
269
  rho_bar = 0.0
229
270
  if self.params["physics"]["quantum"]:
@@ -231,15 +272,19 @@ class Simulation:
231
272
  if self.params["physics"]["hydro"]:
232
273
  rho_bar += jnp.mean(state["rho"])
233
274
  if self.params["physics"]["particles"]:
234
- m_particle = self.params["particles"]["particle_mass"]
235
- n_particles = self.num_particles
236
275
  box_size = self.box_size
237
- rho_bar += m_particle * n_particles / box_size
276
+ if self.params["particles"]["accrete_gas"]:
277
+ rho_bar += jnp.sum(state["mass"]) / box_size**3
278
+ else:
279
+ m_particle = self.params["particles"]["particle_mass"]
280
+ n_particles = self.num_particles
281
+ rho_bar += m_particle * n_particles / box_size**3
282
+ if self.custom_density is not None:
283
+ rho_bar += jnp.mean(self.custom_density(state))
238
284
  return rho_bar
239
285
 
240
286
  def _calc_grav_potential(self, state, k_sq):
241
287
  G = constants["gravitational_constant"]
242
- m_particle = self.params["particles"]["particle_mass"]
243
288
  rho_bar = self._calc_rho_bar(state)
244
289
  rho_tot = 0.0
245
290
  if self.params["physics"]["quantum"]:
@@ -247,7 +292,19 @@ class Simulation:
247
292
  if self.params["physics"]["hydro"]:
248
293
  rho_tot += state["rho"]
249
294
  if self.params["physics"]["particles"]:
250
- rho_tot += bin_particles(state["pos"], self.dx, self.resolution, m_particle)
295
+ multiple_masses = self.params["particles"]["accrete_gas"]
296
+ if multiple_masses:
297
+ m_particles = state["mass"]
298
+ rho_tot += bin_particles(
299
+ state["pos"], m_particles, self.dx, self.resolution, multiple_masses
300
+ )
301
+ else:
302
+ m_particles = self.params["particles"]["particle_mass"]
303
+ rho_tot += bin_particles(
304
+ state["pos"], m_particles, self.dx, self.resolution, multiple_masses
305
+ )
306
+ if self.custom_density is not None:
307
+ rho_tot += self.custom_density(state)
251
308
  if self.params["physics"]["cosmology"]:
252
309
  scale_factor = 1.0 / (1.0 + state["redshift"])
253
310
  rho_bar *= scale_factor
@@ -280,18 +337,32 @@ class Simulation:
280
337
 
281
338
  # Simulation parameters
282
339
  dx = self.dx
283
- m_per_hbar = self.m_per_hbar
284
340
  box_size = self.box_size
341
+ num_cells = self.resolution**3
342
+ m_per_hbar = self.m_per_hbar
285
343
 
286
- dt_fac = 1.0
287
- dt_kin = dt_fac * (m_per_hbar / 6.0) * (dx * dx)
344
+ safety = self.params["time"]["safety_factor"]
345
+ dt_kin = safety * (m_per_hbar / 6.0) * (dx * dx)
288
346
  t_start = self.params["time"]["start"]
289
347
  t_end = self.params["time"]["end"]
290
348
  t_span = t_end - t_start
291
349
  state["t"] = t_start
292
350
 
351
+ use_quantum = self.params["physics"]["quantum"]
352
+ use_gravity = self.params["physics"]["gravity"]
353
+ use_hydro = self.params["physics"]["hydro"]
354
+ use_particles = self.params["physics"]["particles"]
355
+ use_cosmology = self.params["physics"]["cosmology"]
356
+ use_external_potential = self.params["physics"]["external_potential"]
357
+ accrete_gas = self.params["particles"]["accrete_gas"]
358
+ save = self.params["output"]["save"]
359
+ use_custom = self.custom_kick is not None or self.custom_drift is not None
360
+ if use_custom:
361
+ custom_kick = self.custom_kick
362
+ custom_drift = self.custom_drift
363
+
293
364
  # cosmology
294
- if self.params["physics"]["cosmology"]:
365
+ if use_cosmology:
295
366
  z_start = self.params["time"]["start"]
296
367
  z_end = self.params["time"]["end"]
297
368
  omega_matter = self.params["cosmology"]["omega_matter"]
@@ -303,6 +374,17 @@ class Simulation:
303
374
  state["t"] = 0.0
304
375
  state["redshift"] = z_start
305
376
 
377
+ # self-interaction
378
+ a_s = self.scattering_length
379
+ c = constants["speed_of_light"]
380
+ m = self.axion_mass
381
+ si_coeff = None
382
+ si_coeff2 = None
383
+ do_self_interaction = a_s != 0.0
384
+ if do_self_interaction:
385
+ si_coeff = (4.0 * jnp.pi) * (a_s / m) / m_per_hbar**2
386
+ si_coeff2 = (32.0 * jnp.pi**2 / 3.0) * (a_s / (m * c)) ** 2 / m_per_hbar**5
387
+
306
388
  # hydro
307
389
  c_sound = self.params["hydro"]["sound_speed"]
308
390
 
@@ -317,16 +399,14 @@ class Simulation:
317
399
  k_sq = None
318
400
 
319
401
  # Fourier space variables
320
- if self.params["physics"]["gravity"] or self.params["physics"]["quantum"]:
402
+ if use_gravity or use_quantum:
321
403
  kx, ky, kz = self.kgrid
322
404
  k_sq = kx**2 + ky**2 + kz**2
323
405
 
324
406
  # Checkpointer
325
- if self.params["output"]["save"]:
407
+ if save:
326
408
  options = ocp.CheckpointManagerOptions()
327
- checkpoint_dir = checkpoint_dir = os.path.join(
328
- os.getcwd(), self.params["output"]["path"]
329
- )
409
+ checkpoint_dir = os.path.join(os.getcwd(), self.params["output"]["path"])
330
410
  path = os.path.join(os.getcwd(), checkpoint_dir)
331
411
  if jax.process_index() == 0:
332
412
  path = ocp.test_utils.erase_and_create_empty(checkpoint_dir)
@@ -336,80 +416,91 @@ class Simulation:
336
416
 
337
417
  def _kick(state, kx, ky, kz, k_sq, dt):
338
418
  # Kick (half-step)
339
- if (
340
- self.params["physics"]["gravity"]
341
- and self.params["physics"]["external_potential"]
342
- ):
419
+ if use_gravity and use_external_potential:
343
420
  V = self._calc_grav_potential(state, k_sq) + state["V_ext"]
344
- elif self.params["physics"]["gravity"]:
421
+ elif use_gravity:
345
422
  V = self._calc_grav_potential(state, k_sq)
346
- elif self.params["physics"]["external_potential"]:
423
+ elif use_external_potential:
347
424
  V = state["V_ext"]
348
425
 
349
- if (
350
- self.params["physics"]["gravity"]
351
- or self.params["physics"]["external_potential"]
352
- ):
353
- if self.params["physics"]["quantum"]:
354
- state["psi"] = quantum_kick(state["psi"], V, m_per_hbar, dt)
355
- if self.params["physics"]["hydro"]:
426
+ if use_gravity or use_external_potential:
427
+ if use_quantum:
428
+ if do_self_interaction:
429
+ rho = jnp.abs(state["psi"]) ** 2
430
+ V_prime = V + si_coeff * rho + si_coeff2 * rho**2
431
+ state["psi"] = quantum_kick(
432
+ state["psi"], V_prime, m_per_hbar, dt
433
+ )
434
+ else:
435
+ state["psi"] = quantum_kick(state["psi"], V, m_per_hbar, dt)
436
+ if use_hydro:
356
437
  state["vx"], state["vy"], state["vz"] = hydro_accelerate(
357
438
  state["vx"], state["vy"], state["vz"], V, kx, ky, kz, dt
358
439
  )
359
- if self.params["physics"]["particles"]:
440
+ if use_particles:
360
441
  state["vel"] = particles_accelerate(
361
442
  state["vel"], state["pos"], V, kx, ky, kz, dx, dt
362
443
  )
444
+ if use_custom:
445
+ state = custom_kick(state, V, dt)
446
+
447
+ return state
363
448
 
364
449
  def _drift(state, k_sq, dt):
365
450
  # Drift (full-step)
366
- if self.params["physics"]["quantum"]:
451
+ if use_quantum:
367
452
  state["psi"] = quantum_drift(state["psi"], k_sq, m_per_hbar, dt)
368
- if self.params["physics"]["hydro"]:
453
+ if use_hydro:
369
454
  state["rho"], state["vx"], state["vy"], state["vz"] = hydro_fluxes(
370
455
  state["rho"], state["vx"], state["vy"], state["vz"], dt, dx, c_sound
371
456
  )
372
- if self.params["physics"]["particles"]:
457
+ if use_particles:
373
458
  state["pos"] = particles_drift(state["pos"], state["vel"], dt, box_size)
459
+ if use_custom:
460
+ state = custom_drift(state, k_sq, dt)
461
+ if use_hydro and accrete_gas:
462
+ G = constants["gravitational_constant"]
463
+ state["mass"], state["rho"] = particles_accrete_gas(
464
+ state["mass"], state["rho"], state["pos"], G, c_sound, dx, dt
465
+ )
466
+
467
+ return state
374
468
 
375
- @jax.jit
376
469
  def _update(_, carry):
377
470
  # Update the simulation state by one timestep
378
471
  # according to a 2nd-order `kick-drift-kick` scheme
379
472
  state, kx, ky, kz, k_sq = carry
380
- _kick(state, kx, ky, kz, k_sq, 0.5 * dt)
381
- _drift(state, k_sq, dt)
473
+ state = _kick(state, kx, ky, kz, k_sq, 0.5 * dt)
474
+ state = _drift(state, k_sq, dt)
382
475
  # update time & redshift
383
476
  state["t"] += dt
384
- if self.params["physics"]["cosmology"]:
477
+ if use_cosmology:
385
478
  scale_factor = get_next_scale_factor(
386
- state["redshift"],
387
- dt,
388
- self.params["cosmology"]["omega_matter"],
389
- self.params["cosmology"]["omega_lambda"],
390
- self.params["cosmology"]["little_h"],
479
+ state["redshift"], dt, omega_matter, omega_lambda, little_h
391
480
  )
392
481
  state["redshift"] = 1.0 / scale_factor - 1.0
393
- _kick(state, kx, ky, kz, k_sq, 0.5 * dt)
482
+ state = _kick(state, kx, ky, kz, k_sq, 0.5 * dt)
394
483
 
395
484
  return state, kx, ky, kz, k_sq
396
485
 
397
486
  # save initial state
398
487
  if jax.process_index() == 0:
399
488
  print(f"Starting simulation (res={self.resolution}, nt={nt}) ...")
400
- if self.params["output"]["save"]:
489
+ if save:
401
490
  with open(os.path.join(checkpoint_dir, "params.json"), "w") as f:
402
491
  json.dump(self.params, f, indent=2)
403
492
  async_checkpoint_manager.save(0, args=ocp.args.StandardSave(state))
404
493
  plot_sim(state, checkpoint_dir, 0, self.params)
494
+ if self.custom_plot is not None:
495
+ self.custom_plot(state, checkpoint_dir, 0, self.params)
405
496
  async_checkpoint_manager.wait_until_finished()
406
497
 
407
498
  # Simulation Main Loop
408
499
  t_start_timer = time.time()
409
- if self.params["output"]["save"]:
500
+ if save:
410
501
  for i in range(1, num_checkpoints + 1):
411
502
  carry = jax.lax.fori_loop(0, nt_sub, _update, init_val=carry)
412
- state, kx, ky, kz, k_sq = carry
503
+ state, _, _, _, _ = carry
413
504
  jax.block_until_ready(state)
414
505
  # save state
415
506
  async_checkpoint_manager.save(i, args=ocp.args.StandardSave(state))
@@ -417,17 +508,18 @@ class Simulation:
417
508
  elapsed = time.time() - t_start_timer
418
509
  est_total = elapsed / i * num_checkpoints
419
510
  est_remaining = est_total - elapsed
420
- num_cells = self.resolution**3
421
511
  mcups = (num_cells * (i * nt_sub)) / (elapsed * 1.0e6)
422
512
  if jax.process_index() == 0:
423
513
  print(
424
514
  f"{percent:.1f}%: mcups={mcups:.1f}, estimated time left (s): {est_remaining:.1f}"
425
515
  )
426
516
  plot_sim(state, checkpoint_dir, i, self.params)
517
+ if self.custom_plot is not None:
518
+ self.custom_plot(state, checkpoint_dir, i, self.params)
427
519
  async_checkpoint_manager.wait_until_finished()
428
520
  else:
429
521
  carry = jax.lax.fori_loop(0, nt, _update, init_val=carry)
430
- state, kx, ky, kz, k_sq = carry
522
+ state, _, _, _, _ = carry
431
523
  jax.block_until_ready(state)
432
524
  if jax.process_index() == 0:
433
525
  print("Simulation Run Time (s): ", time.time() - t_start_timer)
@@ -31,7 +31,7 @@ def print_distributed_info():
31
31
 
32
32
  def set_up_parameters(user_overwrites):
33
33
  # first load the default params
34
- params_path = importlib.resources.files("jaxion") / "params_default.json"
34
+ params_path = importlib.resources.files("jaxion") / "defaults.json"
35
35
  with params_path.open("r", encoding="utf-8") as f:
36
36
  params = json.load(f)
37
37
 
@@ -82,6 +82,14 @@ def plot_sim(state, checkpoint_dir, i, params):
82
82
  vmax=vmax,
83
83
  extent=[0, nx, 0, nx],
84
84
  )
85
+ if params["physics"]["particles"]:
86
+ # draw particles
87
+ box_size = params["domain"]["box_size"]
88
+ sx = (state["pos"][:, 0] / box_size) * nx
89
+ sy = (state["pos"][:, 1] / box_size) * nx
90
+ plt.plot(
91
+ sx, sy, color="red", marker=".", linestyle="None", markersize=5
92
+ )
85
93
  ax.set_aspect("equal")
86
94
  ax.get_xaxis().set_visible(False)
87
95
  ax.get_yaxis().set_visible(False)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: jaxion
3
- Version: 0.0.5
3
+ Version: 0.0.7
4
4
  Summary: A differentiable simulation library for fuzzy dark matter in JAX
5
5
  Author-email: Philip Mocz <philip.mocz@gmail.com>
6
6
  License-Expression: Apache-2.0
@@ -68,7 +68,7 @@ Jaxion is built for multi-GPU scalability and is fully differentiable. It is a h
68
68
  Jaxion is the simpler companion project to differentiable astrophysics code [Adirondax](https://github.com/AdirondaxProject/adirondax)
69
69
 
70
70
 
71
- ## Getting started
71
+ ## Install Jaxion
72
72
 
73
73
  Install with:
74
74
 
@@ -87,7 +87,7 @@ See the docs for more info on how to [build from source](https://jaxion.readthed
87
87
 
88
88
  ## Examples
89
89
 
90
- Check out the `examples/` directory for demonstrations of using Jaxion.
90
+ Check out the [`examples/`](https://github.com/JaxionProject/jaxion/tree/main/examples/) directory for demonstrations of using Jaxion.
91
91
 
92
92
  <p align="center">
93
93
  <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/cosmological_box">
@@ -6,9 +6,9 @@ jaxion/__init__.py
6
6
  jaxion/analysis.py
7
7
  jaxion/constants.py
8
8
  jaxion/cosmology.py
9
+ jaxion/defaults.json
9
10
  jaxion/gravity.py
10
11
  jaxion/hydro.py
11
- jaxion/params_default.json
12
12
  jaxion/particles.py
13
13
  jaxion/quantum.py
14
14
  jaxion/simulation.py
@@ -10,7 +10,7 @@ license = "Apache-2.0"
10
10
  dynamic = ["version", "dependencies"]
11
11
 
12
12
  [tool.setuptools.packages.find]
13
- include = ["jaxion"]
13
+ include = ["jaxio*"]
14
14
  exclude = ["paper"]
15
15
 
16
16
  [tool.setuptools.package-data]
@@ -12,7 +12,7 @@ def test_tidal_stripping():
12
12
  )
13
13
  assert sim.resolution == 32
14
14
  assert sim.state["t"] > 0.0
15
- assert jnp.mean(jnp.abs(sim.state["psi"])) == pytest.approx(639.0479, rel=rel_tol)
15
+ assert jnp.mean(jnp.abs(sim.state["psi"])) == pytest.approx(162.028, rel=rel_tol)
16
16
 
17
17
 
18
18
  def test_tidal_stripping_distributed_emulate():
@@ -22,7 +22,7 @@ def test_tidal_stripping_distributed_emulate():
22
22
  )
23
23
  assert sim.resolution == 32
24
24
  assert sim.state["t"] > 0.0
25
- assert jnp.mean(jnp.abs(sim.state["psi"])) == pytest.approx(639.0479, rel=rel_tol)
25
+ assert jnp.mean(jnp.abs(sim.state["psi"])) == pytest.approx(162.028, rel=rel_tol)
26
26
 
27
27
 
28
28
  def test_heating_gas():
@@ -41,13 +41,13 @@ def test_heating_stars():
41
41
  )
42
42
  assert sim.resolution == 32
43
43
  assert sim.state["t"] > 0.0
44
- assert jnp.mean(jnp.abs(sim.state["psi"])) == pytest.approx(2574.4248, rel=rel_tol)
44
+ assert jnp.mean(jnp.abs(sim.state["psi"])) == pytest.approx(2574.395, rel=rel_tol)
45
45
  assert jnp.mean(jnp.abs(sim.state["vel"][:, 0])) == pytest.approx(
46
- 16.625286, rel=rel_tol
46
+ 16.625353, rel=rel_tol
47
47
  )
48
48
  assert jnp.mean(jnp.abs(sim.state["vel"][:, 1])) == pytest.approx(
49
- 17.345531, rel=rel_tol
49
+ 17.345486, rel=rel_tol
50
50
  )
51
51
  assert jnp.mean(jnp.abs(sim.state["vel"][:, 2])) == pytest.approx(
52
- 18.218365, rel=rel_tol
52
+ 18.218296, rel=rel_tol
53
53
  )
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes