jaxion 0.0.6__tar.gz → 0.0.8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (26) hide show
  1. {jaxion-0.0.6 → jaxion-0.0.8}/PKG-INFO +1 -1
  2. {jaxion-0.0.6 → jaxion-0.0.8}/jaxion/defaults.json +8 -0
  3. {jaxion-0.0.6 → jaxion-0.0.8}/jaxion/particles.py +104 -1
  4. {jaxion-0.0.6 → jaxion-0.0.8}/jaxion/simulation.py +72 -16
  5. {jaxion-0.0.6 → jaxion-0.0.8}/jaxion/visualization.py +8 -0
  6. {jaxion-0.0.6 → jaxion-0.0.8}/jaxion.egg-info/PKG-INFO +1 -1
  7. {jaxion-0.0.6 → jaxion-0.0.8}/tests/test_examples.py +4 -4
  8. {jaxion-0.0.6 → jaxion-0.0.8}/LICENSE +0 -0
  9. {jaxion-0.0.6 → jaxion-0.0.8}/README.md +0 -0
  10. {jaxion-0.0.6 → jaxion-0.0.8}/jaxion/__init__.py +0 -0
  11. {jaxion-0.0.6 → jaxion-0.0.8}/jaxion/analysis.py +0 -0
  12. {jaxion-0.0.6 → jaxion-0.0.8}/jaxion/constants.py +0 -0
  13. {jaxion-0.0.6 → jaxion-0.0.8}/jaxion/cosmology.py +0 -0
  14. {jaxion-0.0.6 → jaxion-0.0.8}/jaxion/gravity.py +0 -0
  15. {jaxion-0.0.6 → jaxion-0.0.8}/jaxion/hydro.py +0 -0
  16. {jaxion-0.0.6 → jaxion-0.0.8}/jaxion/quantum.py +0 -0
  17. {jaxion-0.0.6 → jaxion-0.0.8}/jaxion/utils.py +0 -0
  18. {jaxion-0.0.6 → jaxion-0.0.8}/jaxion.egg-info/SOURCES.txt +0 -0
  19. {jaxion-0.0.6 → jaxion-0.0.8}/jaxion.egg-info/dependency_links.txt +0 -0
  20. {jaxion-0.0.6 → jaxion-0.0.8}/jaxion.egg-info/requires.txt +0 -0
  21. {jaxion-0.0.6 → jaxion-0.0.8}/jaxion.egg-info/top_level.txt +0 -0
  22. {jaxion-0.0.6 → jaxion-0.0.8}/pyproject.toml +0 -0
  23. {jaxion-0.0.6 → jaxion-0.0.8}/requirements.txt +0 -0
  24. {jaxion-0.0.6 → jaxion-0.0.8}/setup.cfg +0 -0
  25. {jaxion-0.0.6 → jaxion-0.0.8}/tests/test_analysis.py +0 -0
  26. {jaxion-0.0.6 → jaxion-0.0.8}/tests/test_cosmology.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: jaxion
3
- Version: 0.0.6
3
+ Version: 0.0.8
4
4
  Summary: A differentiable simulation library for fuzzy dark matter in JAX
5
5
  Author-email: Philip Mocz <philip.mocz@gmail.com>
6
6
  License-Expression: Apache-2.0
@@ -48,6 +48,10 @@
48
48
  "default": 1.0,
49
49
  "description": "simulation end time [kpc/(km/s)] or [redshift] (cosmology=true)."
50
50
  },
51
+ "safety_factor": {
52
+ "default": 1.0,
53
+ "description": "safety factor for time stepping."
54
+ },
51
55
  "adaptive": {
52
56
  "default": false,
53
57
  "description": "switch on for adaptive time stepping."
@@ -95,6 +99,10 @@
95
99
  "particle_mass": {
96
100
  "default": 1.0,
97
101
  "description": "particle mass [M_sun]."
102
+ },
103
+ "accrete_gas": {
104
+ "default": false,
105
+ "description": "switch on to accrete gas."
98
106
  }
99
107
  },
100
108
  "cosmology": {
@@ -18,7 +18,7 @@ def get_cic_indices_and_weights(pos, dx, resolution):
18
18
  return i, ip1, weight_i, weight_ip1
19
19
 
20
20
 
21
- def bin_particles(pos, dx, resolution, m_particle):
21
+ def bin_particles(pos, m_particles, dx, resolution, multiple_masses):
22
22
  """Bin the particles into the grid using cloud-in-cell weights."""
23
23
  nx = resolution
24
24
  n_particle = pos.shape[0]
@@ -27,6 +27,10 @@ def bin_particles(pos, dx, resolution, m_particle):
27
27
 
28
28
  def deposit_particle(s, rho):
29
29
  """Deposit the particle mass into the grid."""
30
+ if multiple_masses:
31
+ m_particle = m_particles[s]
32
+ else:
33
+ m_particle = m_particles
30
34
  fac = m_particle / (dx * dx * dx)
31
35
  rho = rho.at[i[s, 0], i[s, 1], i[s, 2]].add(
32
36
  w_i[s, 0] * w_i[s, 1] * w_i[s, 2] * fac
@@ -112,3 +116,102 @@ def particles_drift(pos, vel, dt, box_size):
112
116
  pos = jnp.mod(pos, jnp.array([box_size, box_size, box_size]))
113
117
 
114
118
  return pos
119
+
120
+
121
+ def particles_accrete_gas(mass, rho, pos, G, sound_speed, dx, dt):
122
+ """Accrete gas onto particles (Bondi)."""
123
+ n_particle = pos.shape[0]
124
+ resolution = rho.shape[0]
125
+ i, ip1, w_i, w_ip1 = get_cic_indices_and_weights(pos, dx, resolution)
126
+ d_mass = jnp.zeros_like(mass)
127
+ d_rho = jnp.zeros_like(rho)
128
+ lam = jnp.exp(1.5) / 4.0 # ≈ 1.12
129
+ vol = dx**3
130
+
131
+ def accrete(s, deltas):
132
+ """Deposit the particle mass into the grid."""
133
+ d_mass, d_rho = deltas
134
+ dM_fac = dt * 4.0 * jnp.pi * lam * (G * mass[s]) ** 2 / sound_speed**3
135
+ # dM = dM_fac * rho
136
+
137
+ dm = w_i[s, 0] * w_i[s, 1] * w_i[s, 2] * dM_fac * rho[i[s, 0], i[s, 1], i[s, 2]]
138
+ d_rho = d_rho.at[i[s, 0], i[s, 1], i[s, 2]].add(-dm / vol)
139
+ d_mass = d_mass.at[s].add(dm)
140
+
141
+ dm = (
142
+ w_ip1[s, 0]
143
+ * w_i[s, 1]
144
+ * w_i[s, 2]
145
+ * dM_fac
146
+ * rho[ip1[s, 0], i[s, 1], i[s, 2]]
147
+ )
148
+ d_rho = d_rho.at[ip1[s, 0], i[s, 1], i[s, 2]].add(-dm / vol)
149
+ d_mass = d_mass.at[s].add(dm)
150
+
151
+ dm = (
152
+ w_i[s, 0]
153
+ * w_ip1[s, 1]
154
+ * w_i[s, 2]
155
+ * dM_fac
156
+ * rho[i[s, 0], ip1[s, 1], i[s, 2]]
157
+ )
158
+ d_rho = d_rho.at[i[s, 0], ip1[s, 1], i[s, 2]].add(-dm / vol)
159
+ d_mass = d_mass.at[s].add(dm)
160
+
161
+ dm = (
162
+ w_i[s, 0]
163
+ * w_i[s, 1]
164
+ * w_ip1[s, 2]
165
+ * dM_fac
166
+ * rho[i[s, 0], i[s, 1], ip1[s, 2]]
167
+ )
168
+ d_rho = d_rho.at[i[s, 0], i[s, 1], ip1[s, 2]].add(-dm / vol)
169
+ d_mass = d_mass.at[s].add(dm)
170
+
171
+ dm = (
172
+ w_ip1[s, 0]
173
+ * w_ip1[s, 1]
174
+ * w_i[s, 2]
175
+ * dM_fac
176
+ * rho[ip1[s, 0], ip1[s, 1], i[s, 2]]
177
+ )
178
+ d_rho = d_rho.at[ip1[s, 0], ip1[s, 1], i[s, 2]].add(-dm / vol)
179
+ d_mass = d_mass.at[s].add(dm)
180
+
181
+ dm = (
182
+ w_ip1[s, 0]
183
+ * w_i[s, 1]
184
+ * w_ip1[s, 2]
185
+ * dM_fac
186
+ * rho[ip1[s, 0], i[s, 1], ip1[s, 2]]
187
+ )
188
+ d_rho = d_rho.at[ip1[s, 0], i[s, 1], ip1[s, 2]].add(-dm / vol)
189
+ d_mass = d_mass.at[s].add(dm)
190
+
191
+ dm = (
192
+ w_i[s, 0]
193
+ * w_ip1[s, 1]
194
+ * w_ip1[s, 2]
195
+ * dM_fac
196
+ * rho[i[s, 0], ip1[s, 1], ip1[s, 2]]
197
+ )
198
+ d_rho = d_rho.at[i[s, 0], ip1[s, 1], ip1[s, 2]].add(-dm / vol)
199
+ d_mass = d_mass.at[s].add(dm)
200
+
201
+ dm = (
202
+ w_ip1[s, 0]
203
+ * w_ip1[s, 1]
204
+ * w_ip1[s, 2]
205
+ * dM_fac
206
+ * rho[ip1[s, 0], ip1[s, 1], ip1[s, 2]]
207
+ )
208
+ d_rho = d_rho.at[ip1[s, 0], ip1[s, 1], ip1[s, 2]].add(-dm / vol)
209
+ d_mass = d_mass.at[s].add(dm)
210
+
211
+ return d_mass, d_rho
212
+
213
+ d_mass, d_rho = jax.lax.fori_loop(0, n_particle, accrete, (d_mass, d_rho))
214
+ mass = mass + d_mass
215
+ rho = rho + d_rho
216
+
217
+ return mass, rho
@@ -9,7 +9,12 @@ from .constants import constants
9
9
  from .quantum import quantum_kick, quantum_drift, quantum_velocity
10
10
  from .gravity import calculate_gravitational_potential
11
11
  from .hydro import hydro_fluxes, hydro_accelerate
12
- from .particles import particles_accelerate, particles_drift, bin_particles
12
+ from .particles import (
13
+ particles_accelerate,
14
+ particles_drift,
15
+ particles_accrete_gas,
16
+ bin_particles,
17
+ )
13
18
  from .cosmology import get_supercomoving_time_interval, get_next_scale_factor
14
19
  from .utils import (
15
20
  set_up_parameters,
@@ -37,7 +42,7 @@ class Simulation:
37
42
 
38
43
  """
39
44
 
40
- def __init__(self, params, sharding=None):
45
+ def __init__(self, params, sharding=None, checkpoint_number=None):
41
46
  # allow loading directly from a checkpoint path
42
47
  load_from_checkpoint = False
43
48
  checkpoint_dir = ""
@@ -95,11 +100,12 @@ class Simulation:
95
100
  xones, static_argnums=0, in_shardings=None, out_shardings=sharding
96
101
  )
97
102
 
98
- # customfunctions
103
+ # custom functions
99
104
  self.custom_kick = None
100
105
  self.custom_drift = None
101
106
  self.custom_density = None
102
107
  self.custom_plot = None
108
+ self.callback = None
103
109
 
104
110
  # simulation state
105
111
  self.state = {}
@@ -126,13 +132,21 @@ class Simulation:
126
132
  if self.params["physics"]["particles"]:
127
133
  self.state["pos"] = jnp.zeros((self.num_particles, 3))
128
134
  self.state["vel"] = jnp.zeros((self.num_particles, 3))
135
+ if self.params["particles"]["accrete_gas"]:
136
+ self.state["mass"] = (
137
+ jnp.zeros(self.num_particles)
138
+ + self.params["particles"]["particle_mass"]
139
+ )
129
140
 
130
141
  if load_from_checkpoint:
131
142
  options = ocp.CheckpointManagerOptions()
132
143
  async_checkpoint_manager = ocp.CheckpointManager(
133
144
  checkpoint_dir, options=options
134
145
  )
135
- step = async_checkpoint_manager.latest_step()
146
+ if checkpoint_number is not None:
147
+ step = checkpoint_number
148
+ else:
149
+ step = async_checkpoint_manager.latest_step()
136
150
  self.state = async_checkpoint_manager.restore(
137
151
  step, args=ocp.args.StandardRestore(self.state)
138
152
  )
@@ -168,6 +182,23 @@ class Simulation:
168
182
  """
169
183
  return self.box_size / self.resolution
170
184
 
185
+ @property
186
+ def nt(self):
187
+ """
188
+ Return the number of timesteps
189
+ """
190
+ dx = self.dx
191
+ m_per_hbar = self.m_per_hbar
192
+ safety = self.params["time"]["safety_factor"]
193
+ dt_kin = safety * (m_per_hbar / 6.0) * (dx * dx)
194
+ t_start = self.params["time"]["start"]
195
+ t_end = self.params["time"]["end"]
196
+ t_span = t_end - t_start
197
+ num_checkpoints = self.params["output"]["num_checkpoints"]
198
+ nt = int(round(round(t_span / dt_kin) / num_checkpoints) * num_checkpoints)
199
+
200
+ return nt
201
+
171
202
  @property
172
203
  def axion_mass(self):
173
204
  """
@@ -262,17 +293,19 @@ class Simulation:
262
293
  if self.params["physics"]["hydro"]:
263
294
  rho_bar += jnp.mean(state["rho"])
264
295
  if self.params["physics"]["particles"]:
265
- m_particle = self.params["particles"]["particle_mass"]
266
- n_particles = self.num_particles
267
296
  box_size = self.box_size
268
- rho_bar += m_particle * n_particles / box_size
297
+ if self.params["particles"]["accrete_gas"]:
298
+ rho_bar += jnp.sum(state["mass"]) / box_size**3
299
+ else:
300
+ m_particle = self.params["particles"]["particle_mass"]
301
+ n_particles = self.num_particles
302
+ rho_bar += m_particle * n_particles / box_size**3
269
303
  if self.custom_density is not None:
270
304
  rho_bar += jnp.mean(self.custom_density(state))
271
305
  return rho_bar
272
306
 
273
307
  def _calc_grav_potential(self, state, k_sq):
274
308
  G = constants["gravitational_constant"]
275
- m_particle = self.params["particles"]["particle_mass"]
276
309
  rho_bar = self._calc_rho_bar(state)
277
310
  rho_tot = 0.0
278
311
  if self.params["physics"]["quantum"]:
@@ -280,7 +313,17 @@ class Simulation:
280
313
  if self.params["physics"]["hydro"]:
281
314
  rho_tot += state["rho"]
282
315
  if self.params["physics"]["particles"]:
283
- rho_tot += bin_particles(state["pos"], self.dx, self.resolution, m_particle)
316
+ multiple_masses = self.params["particles"]["accrete_gas"]
317
+ if multiple_masses:
318
+ m_particles = state["mass"]
319
+ rho_tot += bin_particles(
320
+ state["pos"], m_particles, self.dx, self.resolution, multiple_masses
321
+ )
322
+ else:
323
+ m_particles = self.params["particles"]["particle_mass"]
324
+ rho_tot += bin_particles(
325
+ state["pos"], m_particles, self.dx, self.resolution, multiple_masses
326
+ )
284
327
  if self.custom_density is not None:
285
328
  rho_tot += self.custom_density(state)
286
329
  if self.params["physics"]["cosmology"]:
@@ -319,8 +362,6 @@ class Simulation:
319
362
  num_cells = self.resolution**3
320
363
  m_per_hbar = self.m_per_hbar
321
364
 
322
- dt_fac = 1.0
323
- dt_kin = dt_fac * (m_per_hbar / 6.0) * (dx * dx)
324
365
  t_start = self.params["time"]["start"]
325
366
  t_end = self.params["time"]["end"]
326
367
  t_span = t_end - t_start
@@ -332,11 +373,15 @@ class Simulation:
332
373
  use_particles = self.params["physics"]["particles"]
333
374
  use_cosmology = self.params["physics"]["cosmology"]
334
375
  use_external_potential = self.params["physics"]["external_potential"]
376
+ accrete_gas = self.params["particles"]["accrete_gas"]
335
377
  save = self.params["output"]["save"]
336
378
  use_custom = self.custom_kick is not None or self.custom_drift is not None
337
379
  if use_custom:
338
380
  custom_kick = self.custom_kick
339
381
  custom_drift = self.custom_drift
382
+ use_callback = self.callback is not None
383
+ if use_callback:
384
+ custom_callback = self.callback
340
385
 
341
386
  # cosmology
342
387
  if use_cosmology:
@@ -367,7 +412,7 @@ class Simulation:
367
412
 
368
413
  # round up to the nearest multiple of num_checkpoints
369
414
  num_checkpoints = self.params["output"]["num_checkpoints"]
370
- nt = int(round(round(t_span / dt_kin) / num_checkpoints) * num_checkpoints)
415
+ nt = self.nt
371
416
  nt_sub = int(round(nt / num_checkpoints))
372
417
  dt = t_span / nt
373
418
 
@@ -435,10 +480,15 @@ class Simulation:
435
480
  state["pos"] = particles_drift(state["pos"], state["vel"], dt, box_size)
436
481
  if use_custom:
437
482
  state = custom_drift(state, k_sq, dt)
483
+ if use_hydro and accrete_gas:
484
+ G = constants["gravitational_constant"]
485
+ state["mass"], state["rho"] = particles_accrete_gas(
486
+ state["mass"], state["rho"], state["pos"], G, c_sound, dx, dt
487
+ )
438
488
 
439
489
  return state
440
490
 
441
- def _update(_, carry):
491
+ def _update(i, carry):
442
492
  # Update the simulation state by one timestep
443
493
  # according to a 2nd-order `kick-drift-kick` scheme
444
494
  state, kx, ky, kz, k_sq = carry
@@ -453,6 +503,9 @@ class Simulation:
453
503
  state["redshift"] = 1.0 / scale_factor - 1.0
454
504
  state = _kick(state, kx, ky, kz, k_sq, 0.5 * dt)
455
505
 
506
+ if use_callback:
507
+ custom_callback(i, state)
508
+
456
509
  return state, kx, ky, kz, k_sq
457
510
 
458
511
  # save initial state
@@ -462,16 +515,19 @@ class Simulation:
462
515
  with open(os.path.join(checkpoint_dir, "params.json"), "w") as f:
463
516
  json.dump(self.params, f, indent=2)
464
517
  async_checkpoint_manager.save(0, args=ocp.args.StandardSave(state))
518
+ async_checkpoint_manager.wait_until_finished()
465
519
  plot_sim(state, checkpoint_dir, 0, self.params)
466
520
  if self.custom_plot is not None:
467
521
  self.custom_plot(state, checkpoint_dir, 0, self.params)
468
- async_checkpoint_manager.wait_until_finished()
469
522
 
470
523
  # Simulation Main Loop
471
524
  t_start_timer = time.time()
472
525
  if save:
473
526
  for i in range(1, num_checkpoints + 1):
474
- carry = jax.lax.fori_loop(0, nt_sub, _update, init_val=carry)
527
+ i_timestep = (i - 1) * nt_sub
528
+ carry = jax.lax.fori_loop(
529
+ i_timestep, i_timestep + nt_sub, _update, init_val=carry
530
+ )
475
531
  state, _, _, _, _ = carry
476
532
  jax.block_until_ready(state)
477
533
  # save state
@@ -485,10 +541,10 @@ class Simulation:
485
541
  print(
486
542
  f"{percent:.1f}%: mcups={mcups:.1f}, estimated time left (s): {est_remaining:.1f}"
487
543
  )
544
+ async_checkpoint_manager.wait_until_finished()
488
545
  plot_sim(state, checkpoint_dir, i, self.params)
489
546
  if self.custom_plot is not None:
490
547
  self.custom_plot(state, checkpoint_dir, i, self.params)
491
- async_checkpoint_manager.wait_until_finished()
492
548
  else:
493
549
  carry = jax.lax.fori_loop(0, nt, _update, init_val=carry)
494
550
  state, _, _, _, _ = carry
@@ -82,6 +82,14 @@ def plot_sim(state, checkpoint_dir, i, params):
82
82
  vmax=vmax,
83
83
  extent=[0, nx, 0, nx],
84
84
  )
85
+ if params["physics"]["particles"]:
86
+ # draw particles
87
+ box_size = params["domain"]["box_size"]
88
+ sx = (state["pos"][:, 0] / box_size) * nx
89
+ sy = (state["pos"][:, 1] / box_size) * nx
90
+ plt.plot(
91
+ sx, sy, color="red", marker=".", linestyle="None", markersize=5
92
+ )
85
93
  ax.set_aspect("equal")
86
94
  ax.get_xaxis().set_visible(False)
87
95
  ax.get_yaxis().set_visible(False)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: jaxion
3
- Version: 0.0.6
3
+ Version: 0.0.8
4
4
  Summary: A differentiable simulation library for fuzzy dark matter in JAX
5
5
  Author-email: Philip Mocz <philip.mocz@gmail.com>
6
6
  License-Expression: Apache-2.0
@@ -41,13 +41,13 @@ def test_heating_stars():
41
41
  )
42
42
  assert sim.resolution == 32
43
43
  assert sim.state["t"] > 0.0
44
- assert jnp.mean(jnp.abs(sim.state["psi"])) == pytest.approx(2574.4248, rel=rel_tol)
44
+ assert jnp.mean(jnp.abs(sim.state["psi"])) == pytest.approx(2574.395, rel=rel_tol)
45
45
  assert jnp.mean(jnp.abs(sim.state["vel"][:, 0])) == pytest.approx(
46
- 16.625286, rel=rel_tol
46
+ 16.625353, rel=rel_tol
47
47
  )
48
48
  assert jnp.mean(jnp.abs(sim.state["vel"][:, 1])) == pytest.approx(
49
- 17.345531, rel=rel_tol
49
+ 17.345486, rel=rel_tol
50
50
  )
51
51
  assert jnp.mean(jnp.abs(sim.state["vel"][:, 2])) == pytest.approx(
52
- 18.218365, rel=rel_tol
52
+ 18.218296, rel=rel_tol
53
53
  )
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes