jaxion 0.0.2__tar.gz → 0.0.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31) hide show
  1. jaxion-0.0.4/PKG-INFO +133 -0
  2. jaxion-0.0.4/README.md +110 -0
  3. {jaxion-0.0.2 → jaxion-0.0.4}/jaxion/__init__.py +1 -0
  4. jaxion-0.0.4/jaxion/analysis.py +58 -0
  5. {jaxion-0.0.2 → jaxion-0.0.4}/jaxion/constants.py +8 -8
  6. jaxion-0.0.4/jaxion/cosmology.py +74 -0
  7. jaxion-0.0.4/jaxion/gravity.py +10 -0
  8. {jaxion-0.0.2 → jaxion-0.0.4}/jaxion/hydro.py +5 -4
  9. jaxion-0.0.4/jaxion/params_default.json +114 -0
  10. {jaxion-0.0.2 → jaxion-0.0.4}/jaxion/particles.py +5 -4
  11. jaxion-0.0.4/jaxion/quantum.py +95 -0
  12. {jaxion-0.0.2 → jaxion-0.0.4}/jaxion/simulation.py +185 -39
  13. {jaxion-0.0.2 → jaxion-0.0.4}/jaxion/utils.py +27 -0
  14. jaxion-0.0.4/jaxion/visualization.py +90 -0
  15. jaxion-0.0.4/jaxion.egg-info/PKG-INFO +133 -0
  16. {jaxion-0.0.2 → jaxion-0.0.4}/jaxion.egg-info/SOURCES.txt +2 -0
  17. {jaxion-0.0.2 → jaxion-0.0.4}/pyproject.toml +4 -0
  18. {jaxion-0.0.2 → jaxion-0.0.4}/tests/test_examples.py +17 -9
  19. jaxion-0.0.2/PKG-INFO +0 -76
  20. jaxion-0.0.2/README.md +0 -53
  21. jaxion-0.0.2/jaxion/gravity.py +0 -9
  22. jaxion-0.0.2/jaxion/params_default.json +0 -77
  23. jaxion-0.0.2/jaxion/quantum.py +0 -15
  24. jaxion-0.0.2/jaxion/visualization.py +0 -74
  25. jaxion-0.0.2/jaxion.egg-info/PKG-INFO +0 -76
  26. {jaxion-0.0.2 → jaxion-0.0.4}/LICENSE +0 -0
  27. {jaxion-0.0.2 → jaxion-0.0.4}/jaxion.egg-info/dependency_links.txt +0 -0
  28. {jaxion-0.0.2 → jaxion-0.0.4}/jaxion.egg-info/requires.txt +0 -0
  29. {jaxion-0.0.2 → jaxion-0.0.4}/jaxion.egg-info/top_level.txt +0 -0
  30. {jaxion-0.0.2 → jaxion-0.0.4}/requirements.txt +0 -0
  31. {jaxion-0.0.2 → jaxion-0.0.4}/setup.cfg +0 -0
jaxion-0.0.4/PKG-INFO ADDED
@@ -0,0 +1,133 @@
1
+ Metadata-Version: 2.4
2
+ Name: jaxion
3
+ Version: 0.0.4
4
+ Summary: A differentiable simulation library for fuzzy dark matter in JAX
5
+ Author-email: Philip Mocz <philip.mocz@gmail.com>
6
+ License-Expression: Apache-2.0
7
+ Project-URL: Documentation, https://jaxion.readthedocs.io
8
+ Project-URL: Homepage, https://github.com/JaxionProject/jaxion
9
+ Requires-Python: >=3.11
10
+ Description-Content-Type: text/markdown
11
+ License-File: LICENSE
12
+ Requires-Dist: jax==0.5.3
13
+ Requires-Dist: jaxdecomp==0.2.7
14
+ Requires-Dist: tensorflow
15
+ Requires-Dist: orbax-checkpoint==0.11.18
16
+ Requires-Dist: optax==0.2.5
17
+ Requires-Dist: numpy
18
+ Requires-Dist: matplotlib
19
+ Requires-Dist: setuptools>=70.1.1
20
+ Provides-Extra: cuda12
21
+ Requires-Dist: jax[cuda12]==0.5.3; extra == "cuda12"
22
+ Dynamic: license-file
23
+
24
+ <p align="center">
25
+ <a href="https://jaxion.readthedocs.io">
26
+ <img src="docs/_static/jaxion-logo.svg" alt="jaxion logo" width="128"/>
27
+ </a>
28
+ </p>
29
+
30
+ # jaxion
31
+
32
+ [![Repo Status][status-badge]][status-link]
33
+ [![PyPI Version Status][pypi-badge]][pypi-link]
34
+ [![Test Status][workflow-test-badge]][workflow-test-link]
35
+ [![Coverage][coverage-badge]][coverage-link]
36
+ [![Readthedocs Status][docs-badge]][docs-link]
37
+ [![License][license-badge]][license-link]
38
+
39
+ [status-link]: https://www.repostatus.org/#active
40
+ [status-badge]: https://www.repostatus.org/badges/latest/active.svg
41
+ [pypi-link]: https://pypi.org/project/jaxion
42
+ [pypi-badge]: https://img.shields.io/pypi/v/jaxion?label=PyPI&logo=pypi
43
+ [workflow-test-link]: https://github.com/JaxionProject/jaxion/actions/workflows/test-package.yml
44
+ [workflow-test-badge]: https://github.com/JaxionProject/jaxion/actions/workflows/test-package.yml/badge.svg?event=push
45
+ [coverage-link]: https://app.codecov.io/gh/JaxionProject/jaxion
46
+ [coverage-badge]: https://codecov.io/github/jaxionproject/jaxion/graph/jaxion-server/badge.svg
47
+ [docs-link]: https://jaxion.readthedocs.io
48
+ [docs-badge]: https://readthedocs.org/projects/jaxion/badge
49
+ [license-link]: https://opensource.org/licenses/Apache-2.0
50
+ [license-badge]: https://img.shields.io/badge/License-Apache_2.0-blue.svg
51
+
52
+ A simple JAX-powered simulation library for numerical experiments of fuzzy dark matter, stars, gas + more!
53
+
54
+ Author: [Philip Mocz (@pmocz)](https://github.com/pmocz/)
55
+
56
+ Jaxion is built for multi-GPU scalability and is fully differentiable. It is a high-performance JAX-based simulation library for modeling fuzzy dark matter alongside stars, gas, and cosmological dynamics. Being differentiable, Jaxion can seamlessly integrate with pipelines for inverse-problems, inference, optimization, and coupling to ML models.
57
+
58
+ Jaxion is the simpler companion project to the differentiable astrophysics code [Adirondax](https://github.com/AdirondaxProject/adirondax).
59
+
60
+
61
+ ## Getting started
62
+
63
+ Install with:
64
+
65
+ ```console
66
+ pip install jaxion
67
+ ```
68
+
69
+ or, for GPU support, use:
70
+
71
+ ```console
72
+ pip install jaxion[cuda12]
73
+ ```
74
+
75
+ See the docs for more info on how to [build from source](https://jaxion.readthedocs.io/en/latest/pages/installation.html).
76
+
77
+
78
+ ## Examples
79
+
80
+ Check out the `examples/` directory for demonstrations of using Jaxion.
81
+
82
+ <p align="center">
83
+ <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/cosmological_box">
84
+ <img src="examples/cosmological_box/movie.gif" alt="cosmological_box" width="128"/>
85
+ </a>
86
+ <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/dynamical_friction">
87
+ <img src="examples/dynamical_friction/movie.gif" alt="dynamical_friction" width="128"/>
88
+ </a>
89
+ <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/heating_gas">
90
+ <img src="examples/heating_gas/movie.gif" alt="heating_gas" width="128"/>
91
+ </a>
92
+ <br>
93
+ <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/heating_stars">
94
+ <img src="examples/heating_stars/movie.gif" alt="heating_stars" width="128"/>
95
+ </a>
96
+ <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/kinetic_condensation">
97
+ <img src="examples/kinetic_condensation/movie.gif" alt="kinetic_condensation" width="128"/>
98
+ </a>
99
+ <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/logo">
100
+ <img src="examples/logo/movie.gif" alt="logo" width="128"/>
101
+ </a>
102
+ <br>
103
+ <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/soliton_binary_merger">
104
+ <img src="examples/soliton_binary_merger/movie.gif" alt="soliton_binary_merger" width="128"/>
105
+ </a>
106
+ <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/soliton_merger">
107
+ <img src="examples/soliton_merger/movie.gif" alt="soliton_merger" width="128"/>
108
+ </a>
109
+ <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/tidal_stripping">
110
+ <img src="examples/tidal_stripping/movie.gif" alt="tidal_stripping" width="128"/>
111
+ </a>
112
+ </p>
113
+
114
+
115
+ ## Links
116
+
117
+ * [Code repository](https://github.com/JaxionProject/jaxion) on GitHub (this page).
118
+ * [Documentation](https://jaxion.readthedocs.io) for up-to-date information about installing and using jaxion.
119
+
120
+
121
+ ## Testing
122
+
123
+ Jaxion is tested with `pytest`. Tests are included in the `tests/` folder.
124
+
125
+
126
+ ## Contributing
127
+
128
+ Jaxion welcomes community contributions of all kinds. Open an issue, or fork the code and submit a pull request. Please check out the [Contributing Guidelines](CONTRIBUTING.md).
129
+
130
+
131
+ ## Cite this repository
132
+
133
+ TODO XXX
jaxion-0.0.4/README.md ADDED
@@ -0,0 +1,110 @@
1
+ <p align="center">
2
+ <a href="https://jaxion.readthedocs.io">
3
+ <img src="docs/_static/jaxion-logo.svg" alt="jaxion logo" width="128"/>
4
+ </a>
5
+ </p>
6
+
7
+ # jaxion
8
+
9
+ [![Repo Status][status-badge]][status-link]
10
+ [![PyPI Version Status][pypi-badge]][pypi-link]
11
+ [![Test Status][workflow-test-badge]][workflow-test-link]
12
+ [![Coverage][coverage-badge]][coverage-link]
13
+ [![Readthedocs Status][docs-badge]][docs-link]
14
+ [![License][license-badge]][license-link]
15
+
16
+ [status-link]: https://www.repostatus.org/#active
17
+ [status-badge]: https://www.repostatus.org/badges/latest/active.svg
18
+ [pypi-link]: https://pypi.org/project/jaxion
19
+ [pypi-badge]: https://img.shields.io/pypi/v/jaxion?label=PyPI&logo=pypi
20
+ [workflow-test-link]: https://github.com/JaxionProject/jaxion/actions/workflows/test-package.yml
21
+ [workflow-test-badge]: https://github.com/JaxionProject/jaxion/actions/workflows/test-package.yml/badge.svg?event=push
22
+ [coverage-link]: https://app.codecov.io/gh/JaxionProject/jaxion
23
+ [coverage-badge]: https://codecov.io/github/jaxionproject/jaxion/graph/jaxion-server/badge.svg
24
+ [docs-link]: https://jaxion.readthedocs.io
25
+ [docs-badge]: https://readthedocs.org/projects/jaxion/badge
26
+ [license-link]: https://opensource.org/licenses/Apache-2.0
27
+ [license-badge]: https://img.shields.io/badge/License-Apache_2.0-blue.svg
28
+
29
+ A simple JAX-powered simulation library for numerical experiments of fuzzy dark matter, stars, gas + more!
30
+
31
+ Author: [Philip Mocz (@pmocz)](https://github.com/pmocz/)
32
+
33
+ Jaxion is built for multi-GPU scalability and is fully differentiable. It is a high-performance JAX-based simulation library for modeling fuzzy dark matter alongside stars, gas, and cosmological dynamics. Being differentiable, Jaxion can seamlessly integrate with pipelines for inverse-problems, inference, optimization, and coupling to ML models.
34
+
35
+ Jaxion is the simpler companion project to the differentiable astrophysics code [Adirondax](https://github.com/AdirondaxProject/adirondax).
36
+
37
+
38
+ ## Getting started
39
+
40
+ Install with:
41
+
42
+ ```console
43
+ pip install jaxion
44
+ ```
45
+
46
+ or, for GPU support, use:
47
+
48
+ ```console
49
+ pip install jaxion[cuda12]
50
+ ```
51
+
52
+ See the docs for more info on how to [build from source](https://jaxion.readthedocs.io/en/latest/pages/installation.html).
53
+
54
+
55
+ ## Examples
56
+
57
+ Check out the `examples/` directory for demonstrations of using Jaxion.
58
+
59
+ <p align="center">
60
+ <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/cosmological_box">
61
+ <img src="examples/cosmological_box/movie.gif" alt="cosmological_box" width="128"/>
62
+ </a>
63
+ <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/dynamical_friction">
64
+ <img src="examples/dynamical_friction/movie.gif" alt="dynamical_friction" width="128"/>
65
+ </a>
66
+ <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/heating_gas">
67
+ <img src="examples/heating_gas/movie.gif" alt="heating_gas" width="128"/>
68
+ </a>
69
+ <br>
70
+ <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/heating_stars">
71
+ <img src="examples/heating_stars/movie.gif" alt="heating_stars" width="128"/>
72
+ </a>
73
+ <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/kinetic_condensation">
74
+ <img src="examples/kinetic_condensation/movie.gif" alt="kinetic_condensation" width="128"/>
75
+ </a>
76
+ <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/logo">
77
+ <img src="examples/logo/movie.gif" alt="logo" width="128"/>
78
+ </a>
79
+ <br>
80
+ <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/soliton_binary_merger">
81
+ <img src="examples/soliton_binary_merger/movie.gif" alt="soliton_binary_merger" width="128"/>
82
+ </a>
83
+ <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/soliton_merger">
84
+ <img src="examples/soliton_merger/movie.gif" alt="soliton_merger" width="128"/>
85
+ </a>
86
+ <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/tidal_stripping">
87
+ <img src="examples/tidal_stripping/movie.gif" alt="tidal_stripping" width="128"/>
88
+ </a>
89
+ </p>
90
+
91
+
92
+ ## Links
93
+
94
+ * [Code repository](https://github.com/JaxionProject/jaxion) on GitHub (this page).
95
+ * [Documentation](https://jaxion.readthedocs.io) for up-to-date information about installing and using jaxion.
96
+
97
+
98
+ ## Testing
99
+
100
+ Jaxion is tested with `pytest`. Tests are included in the `tests/` folder.
101
+
102
+
103
+ ## Contributing
104
+
105
+ Jaxion welcomes community contributions of all kinds. Open an issue, or fork the code and submit a pull request. Please check out the [Contributing Guidelines](CONTRIBUTING.md).
106
+
107
+
108
+ ## Cite this repository
109
+
110
+ TODO XXX
@@ -1,2 +1,3 @@
1
1
  from .simulation import Simulation as Simulation
2
2
  from .constants import constants as constants
3
+ from .analysis import radial_power_spectrum as radial_power_spectrum
@@ -0,0 +1,58 @@
1
+ import jax.numpy as jnp
2
+ import jaxdecomp as jd
3
+
4
+
5
def radial_power_spectrum(data_cube, kx, ky, kz, box_size):
    """
    Computes the radially averaged power spectral density of a 3D datacube.

    Parameters
    ----------
    data_cube : jnp.ndarray
        3D data cube, must be cubic
    kx, ky, kz : jnp.ndarray
        wavenumber grids in each dimension, in physical units (integer
        multiples of 2*pi/box_size), laid out to match the output of
        ``jd.fft.pfft3d(data_cube)``
    box_size : float
        physical size of box

    Returns
    -------
    Pf : jnp.ndarray
        radial power spectrum, one value per shell n = 0 .. nx//2
    k : jnp.ndarray
        shell wavenumbers, 2*pi*n/box_size
    total_power : float
        total power
    """
    dim = data_cube.ndim
    nx = data_cube.shape[0]
    dx = box_size / nx

    # Compute power spectrum
    data_cube_hat = jd.fft.pfft3d(data_cube)
    total_power = 0.5 * jnp.sum(jnp.abs(data_cube_hat) ** 2) / nx**dim * dx**dim
    phi_k = 0.5 * jnp.abs(data_cube_hat) ** 2 / nx**dim * dx**dim
    half_size = nx // 2 + 1

    # Fundamental wavenumber of the periodic box (spacing between shells)
    dk = 2.0 * jnp.pi / box_size

    # Radial shell index of every mode. kx, ky, kz are physical wavenumbers,
    # so |k| is expressed in units of dk before binning into the unit-width
    # shells [n - 0.5, n + 0.5). Binning physical |k| directly only lines up
    # with the returned `k` when box_size == 2*pi.
    k_r = jnp.sqrt(kx**2 + ky**2 + kz**2) / dk

    # Shell-average the power; empty shells give 0 instead of 0/0
    shell_range = (-0.5, half_size - 0.5)
    Pf, _ = jnp.histogram(k_r, range=shell_range, bins=half_size, weights=phi_k)
    norm, _ = jnp.histogram(k_r, range=shell_range, bins=half_size)
    Pf /= norm + (norm == 0)

    k = 2.0 * jnp.pi * jnp.arange(half_size) / box_size

    Pf /= dk**dim

    # Add geometrical factor: surface area 4*pi*k^2 of each 3D shell
    Pf *= 4.0 * jnp.pi * k**2

    return Pf, k, total_power
@@ -1,13 +1,13 @@
1
1
  # Copyright (c) 2025 The Jaxion Team.
2
2
 
3
- # Physical Constants
4
- #
5
- # Jaxion uses a unit system of:
6
- # [L] = kpc
7
- # [V] = km/s
8
- # [M] = Msun
9
- #
10
- # other units are derived from these base units, e.g., [T] = [L]/[V] = kpc / (km/s) ~= 0.978 Gyr
3
+ """
4
+ Physical constants in units of:
5
+ [L] = kpc,
6
+ [V] = km/s,
7
+ [M] = Msun
8
+
9
+ note: other units are derived from these base units, e.g., [T] = [L]/[V] = kpc / (km/s) ~= 0.978 Gyr
10
+ """
11
11
 
12
12
  constants = {
13
13
  "gravitational_constant": 4.30241002e-6, # G / (kpc * (km/s)^2 / Msun)
@@ -0,0 +1,74 @@
1
+ import jax.numpy as jnp
2
+
3
+ # Pure functions for cosmology simulation
4
+
5
+
6
def get_physical_time_interval(z_start, z_end, omega_matter, omega_lambda, little_h):
    """
    Return the physical time elapsed between redshifts z_start and z_end.

    Integrates dt = da / (da/dt) with the trapezoidal rule over 10000
    evenly spaced scale factors, using
        da/dt = H0 * sqrt(omega_matter / a + omega_lambda * a^2),
    with H0 = 0.1 * little_h in km/s/kpc. The result is in code time units
    [kpc/(km/s)] and is negative when z_end > z_start.
    """
    hubble0 = 0.1 * little_h  # Hubble constant in (km/s/kpc)
    scale_factors = jnp.linspace(1.0 / (1.0 + z_start), 1.0 / (1.0 + z_end), 10000)
    expansion_rate = hubble0 * jnp.sqrt(
        (omega_matter / scale_factors) + (omega_lambda * scale_factors**2)
    )
    return jnp.trapezoid(1.0 / expansion_rate, scale_factors)
17
+
18
+
19
def get_supercomoving_time_interval(
    z_start, z_end, omega_matter, omega_lambda, little_h
):
    """
    Return the supercomoving time (dt_hat = a^-2 dt) between two redshifts.

    With da/dt = H0 * sqrt(omega_matter / a + omega_lambda * a^2) and
    H0 = 0.1 * little_h (km/s/kpc), the integrand is
        dt_hat/da = a^-2 / (da/dt),
    which is integrated with the trapezoidal rule over 10000 evenly spaced
    scale factors from a(z_start) to a(z_end).
    """
    hubble0 = 0.1 * little_h  # Hubble constant in (km/s/kpc)
    scale_factors = jnp.linspace(1.0 / (1.0 + z_start), 1.0 / (1.0 + z_end), 10000)
    expansion_rate = hubble0 * jnp.sqrt(
        (omega_matter / scale_factors) + (omega_lambda * scale_factors**2)
    )
    integrand = scale_factors**-2 / expansion_rate
    return jnp.trapezoid(integrand, scale_factors)
35
+
36
+
37
def get_scale_factor(z_start, dt_hat, omega_matter, omega_lambda, little_h):
    """Compute the scale factor corresponding to a given supercomoving time,
    by root finding.

    Bisects on the scale factor a inside the bracket [0, 2] until the
    supercomoving time from z_start to redshift 1/a - 1 matches dt_hat to an
    absolute tolerance of 1e-6, or until 100 iterations have run. The search
    starts from a(z_start). Because the bracket's upper edge is fixed at 2.0,
    targets beyond a = 2 cannot be reached.
    """
    a_lo = 0.0
    a_hi = 2.0  # Set a reasonable upper bound for the scale factor
    a = 1.0 / (1.0 + z_start)
    for _ in range(100):
        residual = (
            get_supercomoving_time_interval(
                z_start, 1.0 / a - 1.0, omega_matter, omega_lambda, little_h
            )
            - dt_hat
        )
        if jnp.abs(residual) < 1e-6:
            break
        # Too much elapsed time -> shrink from above, otherwise from below
        if residual > 0:
            a_hi = a
        else:
            a_lo = a
        a = (a_lo + a_hi) / 2
    return a
60
+
61
+
62
def get_next_scale_factor(z_start, dt_hat, omega_matter, omega_lambda, little_h):
    """Advance scale factor by dt_hat using RK2 (midpoint) on da/dt_hat = a^2 * da/dt.

    da/dt = H0 * sqrt(omega_matter / a + omega_lambda * a^2), with
    H0 = 0.1 * little_h (km/s/kpc). Starts from a = 1 / (1 + z_start) and
    returns the scale factor after one midpoint step of size dt_hat.
    """
    hubble0 = 0.1 * little_h

    def da_dt_hat(scale):
        # a^2 * da/dt, the rate of change in supercomoving time
        return scale**2 * (
            hubble0 * jnp.sqrt((omega_matter / scale) + (omega_lambda * scale**2))
        )

    a0 = 1.0 / (1.0 + z_start)
    slope_start = da_dt_hat(a0)
    slope_mid = da_dt_hat(a0 + 0.5 * dt_hat * slope_start)
    return a0 + dt_hat * slope_mid
@@ -0,0 +1,10 @@
1
+ import jax.numpy as jnp
2
+ import jaxdecomp as jd
3
+
4
+ # Pure functions for gravity calculations
5
+
6
+
7
def calculate_gravitational_potential(rho, k_sq, G, rho_bar):
    """
    Solve the periodic Poisson equation for the gravitational potential.

    Computes V_hat = -FFT(4*pi*G*(rho - rho_bar)) / k_sq and returns the real
    part of its inverse FFT. Where k_sq == 0 the divisor is replaced by 1 to
    avoid division by zero. That mode carries the mean of rho - rho_bar, which
    is presumably ~0 when rho_bar is the mean density (confirm at call site).
    k_sq is assumed to be |k|^2 on the same grid layout as jd.fft.pfft3d output.
    """
    source = 4.0 * jnp.pi * G * (rho - rho_bar)
    safe_k_sq = k_sq + (k_sq == 0)
    V_hat = -jd.fft.pfft3d(source) / safe_k_sq
    return jnp.real(jd.fft.pifft3d(V_hat))
@@ -1,4 +1,5 @@
1
1
  import jax.numpy as jnp
2
+ import jaxdecomp as jd
2
3
 
3
4
  # Pure functions for hydro simulation
4
5
 
@@ -89,11 +90,11 @@ def get_flux(rho_L, vx_L, vy_L, vz_L, rho_R, vx_R, vy_R, vz_R, cs):
89
90
 
90
91
 
91
92
  def hydro_accelerate(vx, vy, vz, V, kx, ky, kz, dt):
92
- V_hat = jnp.fft.fftn(V)
93
+ V_hat = jd.fft.pfft3d(V)
93
94
 
94
- ax = -jnp.real(jnp.fft.ifftn(1.0j * kx * V_hat))
95
- ay = -jnp.real(jnp.fft.ifftn(1.0j * ky * V_hat))
96
- az = -jnp.real(jnp.fft.ifftn(1.0j * kz * V_hat))
95
+ ax = -jnp.real(jd.fft.pifft3d(1.0j * kx * V_hat))
96
+ ay = -jnp.real(jd.fft.pifft3d(1.0j * ky * V_hat))
97
+ az = -jnp.real(jd.fft.pifft3d(1.0j * kz * V_hat))
97
98
 
98
99
  vx += ax * dt
99
100
  vy += ay * dt
@@ -0,0 +1,114 @@
1
+ {
2
+ "physics": {
3
+ "quantum": {
4
+ "default": true,
5
+ "description": "switch on for fuzzy dark matter."
6
+ },
7
+ "gravity": {
8
+ "default": true,
9
+ "description": "switch on for self-gravity."
10
+ },
11
+ "hydro": {
12
+ "default": false,
13
+ "description": "switch on for hydrodynamics (isothermal)."
14
+ },
15
+ "particles": {
16
+ "default": false,
17
+ "description": "switch on for particle-mesh particles."
18
+ },
19
+ "cosmology": {
20
+ "default": false,
21
+ "description": "switch on for cosmological factors/comoving units."
22
+ },
23
+ "external_potential": {
24
+ "default": false,
25
+ "description": "switch on for external gravitational potential."
26
+ }
27
+ },
28
+ "domain": {
29
+ "box_size": {
30
+ "default": 10.0,
31
+ "description": "periodic domain box size [kpc]."
32
+ },
33
+ "resolution_base": {
34
+ "default": 32,
35
+ "description": "base resolution per linear dimension."
36
+ },
37
+ "resolution_multiplier": {
38
+ "default": 1,
39
+ "description": "resolution multiplier."
40
+ }
41
+ },
42
+ "time": {
43
+ "start": {
44
+ "default": 0.0,
45
+ "description": "simulation start time [kpc/(km/s)] or [redshift] (cosmology=true)."
46
+ },
47
+ "end": {
48
+ "default": 1.0,
49
+ "description": "simulation end time [kpc/(km/s)] or [redshift] (cosmology=true)."
50
+ },
51
+ "adaptive": {
52
+ "default": false,
53
+ "description": "switch on for adaptive time stepping."
54
+ }
55
+ },
56
+ "output": {
57
+ "path": {
58
+ "default": "./checkpoints",
59
+ "description": "path to output directory."
60
+ },
61
+ "num_checkpoints": {
62
+ "default": 100,
63
+ "description": "number of checkpoints to save."
64
+ },
65
+ "save": {
66
+ "default": true,
67
+ "description": "switch on to save checkpoints."
68
+ },
69
+ "plot_dynamic_range": {
70
+ "default": 100.0,
71
+ "description": "dynamic range for plotting."
72
+ }
73
+ },
74
+ "quantum": {
75
+ "m_22": {
76
+ "default": 1.0,
77
+ "description": "axion mass [10^{-22} eV]."
78
+ }
79
+ },
80
+ "hydro": {
81
+ "sound_speed": {
82
+ "default": 1.0,
83
+ "description": "isothermal sound speed [km/s]."
84
+ }
85
+ },
86
+ "particles": {
87
+ "num_particles": {
88
+ "default": 0,
89
+ "description": "number of particles."
90
+ },
91
+ "particle_mass": {
92
+ "default": 1.0,
93
+ "description": "particle mass [M_sun]."
94
+ }
95
+ },
96
+ "cosmology": {
97
+ "omega_matter": {
98
+ "default": 0.3,
99
+ "description": "matter density parameter."
100
+ },
101
+ "omega_lambda": {
102
+ "default": 0.7,
103
+ "description": "dark energy density parameter."
104
+ },
105
+ "little_h": {
106
+ "default": 0.7,
107
+ "description": "Hubble parameter little h (H_0 = 100 h km/s/Mpc)."
108
+ }
109
+ },
110
+ "version": {
111
+ "default": "unknown",
112
+ "description": "jaxion version used (auto detected)."
113
+ }
114
+ }
@@ -1,5 +1,6 @@
1
1
  import jax
2
2
  import jax.numpy as jnp
3
+ import jaxdecomp as jd
3
4
 
4
5
  # Pure functions for particle-mesh calculations (stars, BHs, ...)
5
6
 
@@ -64,10 +65,10 @@ def get_acceleration(pos, V, kx, ky, kz, dx):
64
65
  i, ip1, w_i, w_ip1 = get_cic_indices_and_weights(pos, dx, resolution)
65
66
 
66
67
  # find accelerations on the grid
67
- V_hat = jnp.fft.fftn(V)
68
- ax = -jnp.real(jnp.fft.ifftn(1.0j * kx * V_hat))
69
- ay = -jnp.real(jnp.fft.ifftn(1.0j * ky * V_hat))
70
- az = -jnp.real(jnp.fft.ifftn(1.0j * kz * V_hat))
68
+ V_hat = jd.fft.pfft3d(V)
69
+ ax = -jnp.real(jd.fft.pifft3d(1.0j * kx * V_hat))
70
+ ay = -jnp.real(jd.fft.pifft3d(1.0j * ky * V_hat))
71
+ az = -jnp.real(jd.fft.pifft3d(1.0j * kz * V_hat))
71
72
  a_grid = jnp.stack((ax, ay, az), axis=-1)
72
73
 
73
74
  # interpolate the accelerations to the particle positions
@@ -0,0 +1,95 @@
1
+ import jax.numpy as jnp
2
+ import jaxdecomp as jd
3
+
4
+ # Pure functions for quantum simulation
5
+
6
+
7
def quantum_kick(psi, V, m_per_hbar, dt):
    """
    Apply the potential (kick) step of the split-step Schrodinger solver.

    Multiplies psi pointwise by exp(-i * m_per_hbar * dt * V). This is a pure
    phase rotation, so |psi| is unchanged.
    """
    phase = -1.0j * m_per_hbar * dt * V
    return jnp.exp(phase) * psi
10
+
11
+
12
def quantum_drift(psi, k_sq, m_per_hbar, dt):
    """
    Apply the kinetic (drift) step of the split-step Schrodinger solver.

    Transforms psi to Fourier space, multiplies by
    exp(-i * dt * k_sq / (2 * m_per_hbar)), and transforms back.
    """
    propagator = jnp.exp(dt * (-1.0j * k_sq / m_per_hbar / 2.0))
    return jd.fft.pifft3d(propagator * jd.fft.pfft3d(psi))
17
+
18
+
19
def get_gradient(psi, kx, ky, kz):
    """
    Spectral gradient of the wavefunction psi.

    Multiplies the FFT of psi by i*k along each axis and inverse-transforms.
    The results are complex and are not reduced to their real part.

    Returns: (grad_psi_x, grad_psi_y, grad_psi_z)
    """
    psi_hat = jd.fft.pfft3d(psi)
    return tuple(jd.fft.pifft3d(1.0j * k * psi_hat) for k in (kx, ky, kz))
35
+
36
+
37
def _wrap_phase_difference(d_phase):
    """Map a difference of two jnp.angle values into (-pi, pi]."""
    d_phase = jnp.where(d_phase > jnp.pi, d_phase - 2 * jnp.pi, d_phase)
    d_phase = jnp.where(d_phase <= -jnp.pi, d_phase + 2 * jnp.pi, d_phase)
    return d_phase


def quantum_velocity(psi, box_size, m_per_hbar):
    """
    Compute the velocity from the wave-function
        v = nabla S / m,   psi = sqrt(rho) exp(i S / hbar)

    The phase S/hbar = angle(psi) is differenced with periodic forward and
    backward differences along each of the three axes. Each difference is
    wrapped into (-pi, pi] before the two are averaged, so 2*pi jumps in
    the phase do not produce spurious velocities.

    Parameters
    ----------
    psi : jnp.ndarray
        3D complex wavefunction on a cubic periodic grid
    box_size : float
        physical size of the box (grid spacing is box_size / psi.shape[0])
    m_per_hbar : float
        particle mass divided by hbar

    Returns
    -------
    vx, vy, vz : jnp.ndarray
        velocity components on the grid
    """
    N = psi.shape[0]
    dx = box_size / N

    S_per_hbar = jnp.angle(psi)

    velocities = []
    for axis in range(3):
        forward = _wrap_phase_difference(
            jnp.roll(S_per_hbar, -1, axis=axis) - S_per_hbar
        )
        backward = _wrap_phase_difference(
            S_per_hbar - jnp.roll(S_per_hbar, 1, axis=axis)
        )
        forward = forward / dx / m_per_hbar
        backward = backward / dx / m_per_hbar
        # Average forward and backward differences (centred estimate)
        velocities.append(0.5 * (forward + backward))

    vx, vy, vz = velocities
    return vx, vy, vz