jaxion 0.0.5__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -75,6 +75,10 @@
75
75
  "m_22": {
76
76
  "default": 1.0,
77
77
  "description": "axion mass [10^{-22} eV]."
78
+ },
79
+ "f_15": {
80
+ "default": 0.0,
81
+ "description": "self-interaction strength [10^{15} GeV]."
78
82
  }
79
83
  },
80
84
  "hydro": {
jaxion/gravity.py CHANGED
@@ -5,6 +5,6 @@ import jaxdecomp as jd
5
5
 
6
6
 
7
7
  def calculate_gravitational_potential(rho, k_sq, G, rho_bar):
8
- Vhat = -jd.fft.pfft3d(4.0 * jnp.pi * G * (rho - rho_bar)) / (k_sq + (k_sq == 0))
9
- V = jnp.real(jd.fft.pifft3d(Vhat))
8
+ V_hat = -jd.fft.pfft3d(4.0 * jnp.pi * G * (rho - rho_bar)) / (k_sq + (k_sq == 0))
9
+ V = jnp.real(jd.fft.pifft3d(V_hat))
10
10
  return V
jaxion/simulation.py CHANGED
@@ -62,9 +62,10 @@ class Simulation:
62
62
  self.params["physics"]["hydro"]
63
63
  or self.params["physics"]["particles"]
64
64
  or self.params["physics"]["external_potential"]
65
+ or self.params["quantum"]["f_15"] != 0.0
65
66
  ):
66
67
  raise NotImplementedError(
67
- "Cosmological hydro/particles/external_potential physics is not yet implemented."
68
+ "Cosmological hydro/particles/external_potential/SI is not yet implemented."
68
69
  )
69
70
 
70
71
  if self.params["physics"]["hydro"] or self.params["physics"]["particles"]:
@@ -94,6 +95,12 @@ class Simulation:
94
95
  xones, static_argnums=0, in_shardings=None, out_shardings=sharding
95
96
  )
96
97
 
98
+ # customfunctions
99
+ self.custom_kick = None
100
+ self.custom_drift = None
101
+ self.custom_density = None
102
+ self.custom_plot = None
103
+
97
104
  # simulation state
98
105
  self.state = {}
99
106
  self.state["t"] = 0.0
@@ -173,6 +180,23 @@ class Simulation:
173
180
  / constants["speed_of_light"] ** 2
174
181
  )
175
182
 
183
+ @property
184
+ def scattering_length(self):
185
+ """
186
+ Return the axion self-interaction scattering length in the simulation (kpc)
187
+ """
188
+ f_15 = self.params["quantum"]["f_15"]
189
+ if f_15 == 0.0:
190
+ return 0.0
191
+ else:
192
+ f = f_15 * 1.0e24 * constants["electron_volt"]
193
+ sign = 1.0 if f > 0 else -1.0
194
+ hbar = constants["reduced_planck_constant"]
195
+ c = constants["speed_of_light"]
196
+ m = self.axion_mass
197
+ a_s = sign * (hbar * c**3 * m) / (32.0 * jnp.pi * (f**2))
198
+ return a_s
199
+
176
200
  @property
177
201
  def m_per_hbar(self):
178
202
  """
@@ -224,6 +248,13 @@ class Simulation:
224
248
  """
225
249
  return quantum_velocity(self.state["psi"], self.box_size, self.m_per_hbar)
226
250
 
251
+ @property
252
+ def rho_bar(self):
253
+ """
254
+ Return the mean density of the simulation
255
+ """
256
+ return self._calc_rho_bar(self.state)
257
+
227
258
  def _calc_rho_bar(self, state):
228
259
  rho_bar = 0.0
229
260
  if self.params["physics"]["quantum"]:
@@ -235,6 +266,8 @@ class Simulation:
235
266
  n_particles = self.num_particles
236
267
  box_size = self.box_size
237
268
  rho_bar += m_particle * n_particles / box_size
269
+ if self.custom_density is not None:
270
+ rho_bar += jnp.mean(self.custom_density(state))
238
271
  return rho_bar
239
272
 
240
273
  def _calc_grav_potential(self, state, k_sq):
@@ -248,6 +281,8 @@ class Simulation:
248
281
  rho_tot += state["rho"]
249
282
  if self.params["physics"]["particles"]:
250
283
  rho_tot += bin_particles(state["pos"], self.dx, self.resolution, m_particle)
284
+ if self.custom_density is not None:
285
+ rho_tot += self.custom_density(state)
251
286
  if self.params["physics"]["cosmology"]:
252
287
  scale_factor = 1.0 / (1.0 + state["redshift"])
253
288
  rho_bar *= scale_factor
@@ -280,8 +315,9 @@ class Simulation:
280
315
 
281
316
  # Simulation parameters
282
317
  dx = self.dx
283
- m_per_hbar = self.m_per_hbar
284
318
  box_size = self.box_size
319
+ num_cells = self.resolution**3
320
+ m_per_hbar = self.m_per_hbar
285
321
 
286
322
  dt_fac = 1.0
287
323
  dt_kin = dt_fac * (m_per_hbar / 6.0) * (dx * dx)
@@ -290,8 +326,20 @@ class Simulation:
290
326
  t_span = t_end - t_start
291
327
  state["t"] = t_start
292
328
 
329
+ use_quantum = self.params["physics"]["quantum"]
330
+ use_gravity = self.params["physics"]["gravity"]
331
+ use_hydro = self.params["physics"]["hydro"]
332
+ use_particles = self.params["physics"]["particles"]
333
+ use_cosmology = self.params["physics"]["cosmology"]
334
+ use_external_potential = self.params["physics"]["external_potential"]
335
+ save = self.params["output"]["save"]
336
+ use_custom = self.custom_kick is not None or self.custom_drift is not None
337
+ if use_custom:
338
+ custom_kick = self.custom_kick
339
+ custom_drift = self.custom_drift
340
+
293
341
  # cosmology
294
- if self.params["physics"]["cosmology"]:
342
+ if use_cosmology:
295
343
  z_start = self.params["time"]["start"]
296
344
  z_end = self.params["time"]["end"]
297
345
  omega_matter = self.params["cosmology"]["omega_matter"]
@@ -303,6 +351,17 @@ class Simulation:
303
351
  state["t"] = 0.0
304
352
  state["redshift"] = z_start
305
353
 
354
+ # self-interaction
355
+ a_s = self.scattering_length
356
+ c = constants["speed_of_light"]
357
+ m = self.axion_mass
358
+ si_coeff = None
359
+ si_coeff2 = None
360
+ do_self_interaction = a_s != 0.0
361
+ if do_self_interaction:
362
+ si_coeff = (4.0 * jnp.pi) * (a_s / m) / m_per_hbar**2
363
+ si_coeff2 = (32.0 * jnp.pi**2 / 3.0) * (a_s / (m * c)) ** 2 / m_per_hbar**5
364
+
306
365
  # hydro
307
366
  c_sound = self.params["hydro"]["sound_speed"]
308
367
 
@@ -317,16 +376,14 @@ class Simulation:
317
376
  k_sq = None
318
377
 
319
378
  # Fourier space variables
320
- if self.params["physics"]["gravity"] or self.params["physics"]["quantum"]:
379
+ if use_gravity or use_quantum:
321
380
  kx, ky, kz = self.kgrid
322
381
  k_sq = kx**2 + ky**2 + kz**2
323
382
 
324
383
  # Checkpointer
325
- if self.params["output"]["save"]:
384
+ if save:
326
385
  options = ocp.CheckpointManagerOptions()
327
- checkpoint_dir = checkpoint_dir = os.path.join(
328
- os.getcwd(), self.params["output"]["path"]
329
- )
386
+ checkpoint_dir = os.path.join(os.getcwd(), self.params["output"]["path"])
330
387
  path = os.path.join(os.getcwd(), checkpoint_dir)
331
388
  if jax.process_index() == 0:
332
389
  path = ocp.test_utils.erase_and_create_empty(checkpoint_dir)
@@ -336,80 +393,86 @@ class Simulation:
336
393
 
337
394
  def _kick(state, kx, ky, kz, k_sq, dt):
338
395
  # Kick (half-step)
339
- if (
340
- self.params["physics"]["gravity"]
341
- and self.params["physics"]["external_potential"]
342
- ):
396
+ if use_gravity and use_external_potential:
343
397
  V = self._calc_grav_potential(state, k_sq) + state["V_ext"]
344
- elif self.params["physics"]["gravity"]:
398
+ elif use_gravity:
345
399
  V = self._calc_grav_potential(state, k_sq)
346
- elif self.params["physics"]["external_potential"]:
400
+ elif use_external_potential:
347
401
  V = state["V_ext"]
348
402
 
349
- if (
350
- self.params["physics"]["gravity"]
351
- or self.params["physics"]["external_potential"]
352
- ):
353
- if self.params["physics"]["quantum"]:
354
- state["psi"] = quantum_kick(state["psi"], V, m_per_hbar, dt)
355
- if self.params["physics"]["hydro"]:
403
+ if use_gravity or use_external_potential:
404
+ if use_quantum:
405
+ if do_self_interaction:
406
+ rho = jnp.abs(state["psi"]) ** 2
407
+ V_prime = V + si_coeff * rho + si_coeff2 * rho**2
408
+ state["psi"] = quantum_kick(
409
+ state["psi"], V_prime, m_per_hbar, dt
410
+ )
411
+ else:
412
+ state["psi"] = quantum_kick(state["psi"], V, m_per_hbar, dt)
413
+ if use_hydro:
356
414
  state["vx"], state["vy"], state["vz"] = hydro_accelerate(
357
415
  state["vx"], state["vy"], state["vz"], V, kx, ky, kz, dt
358
416
  )
359
- if self.params["physics"]["particles"]:
417
+ if use_particles:
360
418
  state["vel"] = particles_accelerate(
361
419
  state["vel"], state["pos"], V, kx, ky, kz, dx, dt
362
420
  )
421
+ if use_custom:
422
+ state = custom_kick(state, V, dt)
423
+
424
+ return state
363
425
 
364
426
  def _drift(state, k_sq, dt):
365
427
  # Drift (full-step)
366
- if self.params["physics"]["quantum"]:
428
+ if use_quantum:
367
429
  state["psi"] = quantum_drift(state["psi"], k_sq, m_per_hbar, dt)
368
- if self.params["physics"]["hydro"]:
430
+ if use_hydro:
369
431
  state["rho"], state["vx"], state["vy"], state["vz"] = hydro_fluxes(
370
432
  state["rho"], state["vx"], state["vy"], state["vz"], dt, dx, c_sound
371
433
  )
372
- if self.params["physics"]["particles"]:
434
+ if use_particles:
373
435
  state["pos"] = particles_drift(state["pos"], state["vel"], dt, box_size)
436
+ if use_custom:
437
+ state = custom_drift(state, k_sq, dt)
438
+
439
+ return state
374
440
 
375
- @jax.jit
376
441
  def _update(_, carry):
377
442
  # Update the simulation state by one timestep
378
443
  # according to a 2nd-order `kick-drift-kick` scheme
379
444
  state, kx, ky, kz, k_sq = carry
380
- _kick(state, kx, ky, kz, k_sq, 0.5 * dt)
381
- _drift(state, k_sq, dt)
445
+ state = _kick(state, kx, ky, kz, k_sq, 0.5 * dt)
446
+ state = _drift(state, k_sq, dt)
382
447
  # update time & redshift
383
448
  state["t"] += dt
384
- if self.params["physics"]["cosmology"]:
449
+ if use_cosmology:
385
450
  scale_factor = get_next_scale_factor(
386
- state["redshift"],
387
- dt,
388
- self.params["cosmology"]["omega_matter"],
389
- self.params["cosmology"]["omega_lambda"],
390
- self.params["cosmology"]["little_h"],
451
+ state["redshift"], dt, omega_matter, omega_lambda, little_h
391
452
  )
392
453
  state["redshift"] = 1.0 / scale_factor - 1.0
393
- _kick(state, kx, ky, kz, k_sq, 0.5 * dt)
454
+ state = _kick(state, kx, ky, kz, k_sq, 0.5 * dt)
394
455
 
395
456
  return state, kx, ky, kz, k_sq
396
457
 
397
458
  # save initial state
398
459
  if jax.process_index() == 0:
399
460
  print(f"Starting simulation (res={self.resolution}, nt={nt}) ...")
400
- if self.params["output"]["save"]:
461
+ if save:
401
462
  with open(os.path.join(checkpoint_dir, "params.json"), "w") as f:
402
463
  json.dump(self.params, f, indent=2)
403
464
  async_checkpoint_manager.save(0, args=ocp.args.StandardSave(state))
404
465
  plot_sim(state, checkpoint_dir, 0, self.params)
466
+ if self.custom_plot is not None:
467
+ self.custom_plot(state, checkpoint_dir, 0, self.params)
405
468
  async_checkpoint_manager.wait_until_finished()
406
469
 
407
470
  # Simulation Main Loop
408
471
  t_start_timer = time.time()
409
- if self.params["output"]["save"]:
472
+ if save:
410
473
  for i in range(1, num_checkpoints + 1):
411
474
  carry = jax.lax.fori_loop(0, nt_sub, _update, init_val=carry)
412
- state, kx, ky, kz, k_sq = carry
475
+ state, _, _, _, _ = carry
413
476
  jax.block_until_ready(state)
414
477
  # save state
415
478
  async_checkpoint_manager.save(i, args=ocp.args.StandardSave(state))
@@ -417,17 +480,18 @@ class Simulation:
417
480
  elapsed = time.time() - t_start_timer
418
481
  est_total = elapsed / i * num_checkpoints
419
482
  est_remaining = est_total - elapsed
420
- num_cells = self.resolution**3
421
483
  mcups = (num_cells * (i * nt_sub)) / (elapsed * 1.0e6)
422
484
  if jax.process_index() == 0:
423
485
  print(
424
486
  f"{percent:.1f}%: mcups={mcups:.1f}, estimated time left (s): {est_remaining:.1f}"
425
487
  )
426
488
  plot_sim(state, checkpoint_dir, i, self.params)
489
+ if self.custom_plot is not None:
490
+ self.custom_plot(state, checkpoint_dir, i, self.params)
427
491
  async_checkpoint_manager.wait_until_finished()
428
492
  else:
429
493
  carry = jax.lax.fori_loop(0, nt, _update, init_val=carry)
430
- state, kx, ky, kz, k_sq = carry
494
+ state, _, _, _, _ = carry
431
495
  jax.block_until_ready(state)
432
496
  if jax.process_index() == 0:
433
497
  print("Simulation Run Time (s): ", time.time() - t_start_timer)
jaxion/utils.py CHANGED
@@ -31,7 +31,7 @@ def print_distributed_info():
31
31
 
32
32
  def set_up_parameters(user_overwrites):
33
33
  # first load the default params
34
- params_path = importlib.resources.files("jaxion") / "params_default.json"
34
+ params_path = importlib.resources.files("jaxion") / "defaults.json"
35
35
  with params_path.open("r", encoding="utf-8") as f:
36
36
  params = json.load(f)
37
37
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: jaxion
3
- Version: 0.0.5
3
+ Version: 0.0.6
4
4
  Summary: A differentiable simulation library for fuzzy dark matter in JAX
5
5
  Author-email: Philip Mocz <philip.mocz@gmail.com>
6
6
  License-Expression: Apache-2.0
@@ -68,7 +68,7 @@ Jaxion is built for multi-GPU scalability and is fully differentiable. It is a h
68
68
  Jaxion is the simpler companion project to differentiable astrophysics code [Adirondax](https://github.com/AdirondaxProject/adirondax)
69
69
 
70
70
 
71
- ## Getting started
71
+ ## Install Jaxion
72
72
 
73
73
  Install with:
74
74
 
@@ -87,7 +87,7 @@ See the docs for more info on how to [build from source](https://jaxion.readthed
87
87
 
88
88
  ## Examples
89
89
 
90
- Check out the `examples/` directory for demonstrations of using Jaxion.
90
+ Check out the [`examples/`](https://github.com/JaxionProject/jaxion/tree/main/examples/) directory for demonstrations of using Jaxion.
91
91
 
92
92
  <p align="center">
93
93
  <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/cosmological_box">
@@ -2,16 +2,16 @@ jaxion/__init__.py,sha256=Hdji1UQ47lG24Pqcy6UUq9L0-qy6m9Ax41L0vIYzBto,164
2
2
  jaxion/analysis.py,sha256=4YT9Z2dkFoXwft3fQM1HyynVPlIdtRd80VtI2vWTyq4,1568
3
3
  jaxion/constants.py,sha256=HyY2ktKQakv78jD1yQvFdM3sklUJcPgDMYlTsSPQTxI,512
4
4
  jaxion/cosmology.py,sha256=UC1McXNTXGoPRYXn0nI2-csVkJWL-ZBNoCa44oU1b4w,2681
5
- jaxion/gravity.py,sha256=3brRZelKm-soXqk_Lt3SqhbZ00woJCraqwdMuR-KooA,291
5
+ jaxion/defaults.json,sha256=EqFHV9HlLIRvTJrfbT5-AI0GfCLBd0ViwTeUV_B8YIw,2883
6
+ jaxion/gravity.py,sha256=2smqy_jjmr0VkMGzLPMYjLahHBMfZ2nNUx0gUiTWcDI,293
6
7
  jaxion/hydro.py,sha256=KoJ02tRpAc4V3Ofzw4zbHLRaE2GdIatbOBE04_LsSRw,6980
7
- jaxion/params_default.json,sha256=9CJrhEPsv5zGEs7_WqFyuccCDipPCDhXgKzVdqOsOWE,2775
8
8
  jaxion/particles.py,sha256=pMopGvoZ0J_3EviD0WnTMmiebU9h2_8IO-p6I-E5DEU,3980
9
9
  jaxion/quantum.py,sha256=GWOpN6ipfEw-6Ah2zQpxS3oqeSt_iHMDSgnVYSjXY5E,3321
10
- jaxion/simulation.py,sha256=2YkHh3tUVvg5tUNnOxf4s7wGeuMntYUJcJWV0M-3Pl8,16267
11
- jaxion/utils.py,sha256=f8SvJjqzcW2K91qbPNqrsjfVjyPShuf50yoSHc0YqYE,4093
10
+ jaxion/simulation.py,sha256=rxncry9xPAVpohGkLqSo9vft-qaGTnh4f-wKk-A-Y2A,18496
11
+ jaxion/utils.py,sha256=OCpJ3crZqr5VFacymYzi5BkRqFCVBcneoS44wV9mZPg,4087
12
12
  jaxion/visualization.py,sha256=Vx_xuEE7BB1AkEGaY9KHNSIIDxJzUVT104o-3uglW8o,2966
13
- jaxion-0.0.5.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
14
- jaxion-0.0.5.dist-info/METADATA,sha256=Ot_IWpsCk6l7KX94fCs4j1LLcJQLlQFmp-jYCyj_yWY,6378
15
- jaxion-0.0.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
16
- jaxion-0.0.5.dist-info/top_level.txt,sha256=S1OV2VdlDG_9UwpKOIji4itQGOS-VWUOWUi3GeXWzt0,7
17
- jaxion-0.0.5.dist-info/RECORD,,
13
+ jaxion-0.0.6.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
14
+ jaxion-0.0.6.dist-info/METADATA,sha256=uNykshiiqSIPVw2UM44_zREemfnkaJlGdVqZKyi7X-8,6440
15
+ jaxion-0.0.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
16
+ jaxion-0.0.6.dist-info/top_level.txt,sha256=S1OV2VdlDG_9UwpKOIji4itQGOS-VWUOWUi3GeXWzt0,7
17
+ jaxion-0.0.6.dist-info/RECORD,,
File without changes