jaxion 0.0.2__tar.gz → 0.0.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxion-0.0.4/PKG-INFO +133 -0
- jaxion-0.0.4/README.md +110 -0
- {jaxion-0.0.2 → jaxion-0.0.4}/jaxion/__init__.py +1 -0
- jaxion-0.0.4/jaxion/analysis.py +58 -0
- {jaxion-0.0.2 → jaxion-0.0.4}/jaxion/constants.py +8 -8
- jaxion-0.0.4/jaxion/cosmology.py +74 -0
- jaxion-0.0.4/jaxion/gravity.py +10 -0
- {jaxion-0.0.2 → jaxion-0.0.4}/jaxion/hydro.py +5 -4
- jaxion-0.0.4/jaxion/params_default.json +114 -0
- {jaxion-0.0.2 → jaxion-0.0.4}/jaxion/particles.py +5 -4
- jaxion-0.0.4/jaxion/quantum.py +95 -0
- {jaxion-0.0.2 → jaxion-0.0.4}/jaxion/simulation.py +185 -39
- {jaxion-0.0.2 → jaxion-0.0.4}/jaxion/utils.py +27 -0
- jaxion-0.0.4/jaxion/visualization.py +90 -0
- jaxion-0.0.4/jaxion.egg-info/PKG-INFO +133 -0
- {jaxion-0.0.2 → jaxion-0.0.4}/jaxion.egg-info/SOURCES.txt +2 -0
- {jaxion-0.0.2 → jaxion-0.0.4}/pyproject.toml +4 -0
- {jaxion-0.0.2 → jaxion-0.0.4}/tests/test_examples.py +17 -9
- jaxion-0.0.2/PKG-INFO +0 -76
- jaxion-0.0.2/README.md +0 -53
- jaxion-0.0.2/jaxion/gravity.py +0 -9
- jaxion-0.0.2/jaxion/params_default.json +0 -77
- jaxion-0.0.2/jaxion/quantum.py +0 -15
- jaxion-0.0.2/jaxion/visualization.py +0 -74
- jaxion-0.0.2/jaxion.egg-info/PKG-INFO +0 -76
- {jaxion-0.0.2 → jaxion-0.0.4}/LICENSE +0 -0
- {jaxion-0.0.2 → jaxion-0.0.4}/jaxion.egg-info/dependency_links.txt +0 -0
- {jaxion-0.0.2 → jaxion-0.0.4}/jaxion.egg-info/requires.txt +0 -0
- {jaxion-0.0.2 → jaxion-0.0.4}/jaxion.egg-info/top_level.txt +0 -0
- {jaxion-0.0.2 → jaxion-0.0.4}/requirements.txt +0 -0
- {jaxion-0.0.2 → jaxion-0.0.4}/setup.cfg +0 -0
jaxion-0.0.4/PKG-INFO
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: jaxion
|
|
3
|
+
Version: 0.0.4
|
|
4
|
+
Summary: A differentiable simulation library for fuzzy dark matter in JAX
|
|
5
|
+
Author-email: Philip Mocz <philip.mocz@gmail.com>
|
|
6
|
+
License-Expression: Apache-2.0
|
|
7
|
+
Project-URL: Documentation, https://jaxion.readthedocs.io
|
|
8
|
+
Project-URL: Homepage, https://github.com/JaxionProject/jaxion
|
|
9
|
+
Requires-Python: >=3.11
|
|
10
|
+
Description-Content-Type: text/markdown
|
|
11
|
+
License-File: LICENSE
|
|
12
|
+
Requires-Dist: jax==0.5.3
|
|
13
|
+
Requires-Dist: jaxdecomp==0.2.7
|
|
14
|
+
Requires-Dist: tensorflow
|
|
15
|
+
Requires-Dist: orbax-checkpoint==0.11.18
|
|
16
|
+
Requires-Dist: optax==0.2.5
|
|
17
|
+
Requires-Dist: numpy
|
|
18
|
+
Requires-Dist: matplotlib
|
|
19
|
+
Requires-Dist: setuptools>=70.1.1
|
|
20
|
+
Provides-Extra: cuda12
|
|
21
|
+
Requires-Dist: jax[cuda12]==0.5.3; extra == "cuda12"
|
|
22
|
+
Dynamic: license-file
|
|
23
|
+
|
|
24
|
+
<p align="center">
|
|
25
|
+
<a href="https://jaxion.readthedocs.io">
|
|
26
|
+
<img src="docs/_static/jaxion-logo.svg" alt="jaxion logo" width="128"/>
|
|
27
|
+
</a>
|
|
28
|
+
</p>
|
|
29
|
+
|
|
30
|
+
# jaxion
|
|
31
|
+
|
|
32
|
+
[![Repo Status][status-badge]][status-link]
|
|
33
|
+
[![PyPI Version Status][pypi-badge]][pypi-link]
|
|
34
|
+
[![Test Status][workflow-test-badge]][workflow-test-link]
|
|
35
|
+
[![Coverage][coverage-badge]][coverage-link]
|
|
36
|
+
[![Readthedocs Status][docs-badge]][docs-link]
|
|
37
|
+
[![License][license-badge]][license-link]
|
|
38
|
+
|
|
39
|
+
[status-link]: https://www.repostatus.org/#active
|
|
40
|
+
[status-badge]: https://www.repostatus.org/badges/latest/active.svg
|
|
41
|
+
[pypi-link]: https://pypi.org/project/jaxion
|
|
42
|
+
[pypi-badge]: https://img.shields.io/pypi/v/jaxion?label=PyPI&logo=pypi
|
|
43
|
+
[workflow-test-link]: https://github.com/JaxionProject/jaxion/actions/workflows/test-package.yml
|
|
44
|
+
[workflow-test-badge]: https://github.com/JaxionProject/jaxion/actions/workflows/test-package.yml/badge.svg?event=push
|
|
45
|
+
[coverage-link]: https://app.codecov.io/gh/JaxionProject/jaxion
|
|
46
|
+
[coverage-badge]: https://codecov.io/github/jaxionproject/jaxion/graph/jaxion-server/badge.svg
|
|
47
|
+
[docs-link]: https://jaxion.readthedocs.io
|
|
48
|
+
[docs-badge]: https://readthedocs.org/projects/jaxion/badge
|
|
49
|
+
[license-link]: https://opensource.org/licenses/Apache-2.0
|
|
50
|
+
[license-badge]: https://img.shields.io/badge/License-Apache_2.0-blue.svg
|
|
51
|
+
|
|
52
|
+
A simple JAX-powered simulation library for numerical experiments of fuzzy dark matter, stars, gas + more!
|
|
53
|
+
|
|
54
|
+
Author: [Philip Mocz (@pmocz)](https://github.com/pmocz/)
|
|
55
|
+
|
|
56
|
+
Jaxion is built for multi-GPU scalability and is fully differentiable. It is a high-performance JAX-based simulation library for modeling fuzzy dark matter alongside stars, gas, and cosmological dynamics. Being differentiable, Jaxion can seamlessly integrate with pipelines for inverse problems, inference, optimization, and coupling to ML models.
|
|
57
|
+
|
|
58
|
+
Jaxion is the simpler companion project to the differentiable astrophysics code [Adirondax](https://github.com/AdirondaxProject/adirondax).
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
## Getting started
|
|
62
|
+
|
|
63
|
+
Install with:
|
|
64
|
+
|
|
65
|
+
```console
|
|
66
|
+
pip install jaxion
|
|
67
|
+
```
|
|
68
|
+
|
|
69
|
+
or, for GPU support, use:
|
|
70
|
+
|
|
71
|
+
```console
|
|
72
|
+
pip install jaxion[cuda12]
|
|
73
|
+
```
|
|
74
|
+
|
|
75
|
+
See the docs for more info on how to [build from source](https://jaxion.readthedocs.io/en/latest/pages/installation.html).
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
## Examples
|
|
79
|
+
|
|
80
|
+
Check out the `examples/` directory for demonstrations of using Jaxion.
|
|
81
|
+
|
|
82
|
+
<p align="center">
|
|
83
|
+
<a href="https://github.com/JaxionProject/jaxion/tree/main/examples/cosmological_box">
|
|
84
|
+
<img src="examples/cosmological_box/movie.gif" alt="cosmological_box" width="128"/>
|
|
85
|
+
</a>
|
|
86
|
+
<a href="https://github.com/JaxionProject/jaxion/tree/main/examples/dynamical_friction">
|
|
87
|
+
<img src="examples/dynamical_friction/movie.gif" alt="dynamical_friction" width="128"/>
|
|
88
|
+
</a>
|
|
89
|
+
<a href="https://github.com/JaxionProject/jaxion/tree/main/examples/heating_gas">
|
|
90
|
+
<img src="examples/heating_gas/movie.gif" alt="heating_gas" width="128"/>
|
|
91
|
+
</a>
|
|
92
|
+
<br>
|
|
93
|
+
<a href="https://github.com/JaxionProject/jaxion/tree/main/examples/heating_stars">
|
|
94
|
+
<img src="examples/heating_stars/movie.gif" alt="heating_stars" width="128"/>
|
|
95
|
+
</a>
|
|
96
|
+
<a href="https://github.com/JaxionProject/jaxion/tree/main/examples/kinetic_condensation">
|
|
97
|
+
<img src="examples/kinetic_condensation/movie.gif" alt="kinetic_condensation" width="128"/>
|
|
98
|
+
</a>
|
|
99
|
+
<a href="https://github.com/JaxionProject/jaxion/tree/main/examples/logo">
|
|
100
|
+
<img src="examples/logo/movie.gif" alt="logo" width="128"/>
|
|
101
|
+
</a>
|
|
102
|
+
<br>
|
|
103
|
+
<a href="https://github.com/JaxionProject/jaxion/tree/main/examples/soliton_binary_merger">
|
|
104
|
+
<img src="examples/soliton_binary_merger/movie.gif" alt="soliton_binary_merger" width="128"/>
|
|
105
|
+
</a>
|
|
106
|
+
<a href="https://github.com/JaxionProject/jaxion/tree/main/examples/soliton_merger">
|
|
107
|
+
<img src="examples/soliton_merger/movie.gif" alt="soliton_merger" width="128"/>
|
|
108
|
+
</a>
|
|
109
|
+
<a href="https://github.com/JaxionProject/jaxion/tree/main/examples/tidal_stripping">
|
|
110
|
+
<img src="examples/tidal_stripping/movie.gif" alt="tidal_stripping" width="128"/>
|
|
111
|
+
</a>
|
|
112
|
+
</p>
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
## Links
|
|
116
|
+
|
|
117
|
+
* [Code repository](https://github.com/JaxionProject/jaxion) on GitHub (this page).
|
|
118
|
+
* [Documentation](https://jaxion.readthedocs.io) for up-to-date information about installing and using jaxion.
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
## Testing
|
|
122
|
+
|
|
123
|
+
Jaxion is tested with `pytest`. Tests are included in the `tests/` folder.
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
## Contributing
|
|
127
|
+
|
|
128
|
+
Jaxion welcomes community contributions of all kinds. Open an issue, or fork the code and submit a pull request. Please check out the [Contributing Guidelines](CONTRIBUTING.md).
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
## Cite this repository
|
|
132
|
+
|
|
133
|
+
TODO XXX
|
jaxion-0.0.4/README.md
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
<p align="center">
|
|
2
|
+
<a href="https://jaxion.readthedocs.io">
|
|
3
|
+
<img src="docs/_static/jaxion-logo.svg" alt="jaxion logo" width="128"/>
|
|
4
|
+
</a>
|
|
5
|
+
</p>
|
|
6
|
+
|
|
7
|
+
# jaxion
|
|
8
|
+
|
|
9
|
+
[![Repo Status][status-badge]][status-link]
|
|
10
|
+
[![PyPI Version Status][pypi-badge]][pypi-link]
|
|
11
|
+
[![Test Status][workflow-test-badge]][workflow-test-link]
|
|
12
|
+
[![Coverage][coverage-badge]][coverage-link]
|
|
13
|
+
[![Readthedocs Status][docs-badge]][docs-link]
|
|
14
|
+
[![License][license-badge]][license-link]
|
|
15
|
+
|
|
16
|
+
[status-link]: https://www.repostatus.org/#active
|
|
17
|
+
[status-badge]: https://www.repostatus.org/badges/latest/active.svg
|
|
18
|
+
[pypi-link]: https://pypi.org/project/jaxion
|
|
19
|
+
[pypi-badge]: https://img.shields.io/pypi/v/jaxion?label=PyPI&logo=pypi
|
|
20
|
+
[workflow-test-link]: https://github.com/JaxionProject/jaxion/actions/workflows/test-package.yml
|
|
21
|
+
[workflow-test-badge]: https://github.com/JaxionProject/jaxion/actions/workflows/test-package.yml/badge.svg?event=push
|
|
22
|
+
[coverage-link]: https://app.codecov.io/gh/JaxionProject/jaxion
|
|
23
|
+
[coverage-badge]: https://codecov.io/github/jaxionproject/jaxion/graph/jaxion-server/badge.svg
|
|
24
|
+
[docs-link]: https://jaxion.readthedocs.io
|
|
25
|
+
[docs-badge]: https://readthedocs.org/projects/jaxion/badge
|
|
26
|
+
[license-link]: https://opensource.org/licenses/Apache-2.0
|
|
27
|
+
[license-badge]: https://img.shields.io/badge/License-Apache_2.0-blue.svg
|
|
28
|
+
|
|
29
|
+
A simple JAX-powered simulation library for numerical experiments of fuzzy dark matter, stars, gas + more!
|
|
30
|
+
|
|
31
|
+
Author: [Philip Mocz (@pmocz)](https://github.com/pmocz/)
|
|
32
|
+
|
|
33
|
+
Jaxion is built for multi-GPU scalability and is fully differentiable. It is a high-performance JAX-based simulation library for modeling fuzzy dark matter alongside stars, gas, and cosmological dynamics. Being differentiable, Jaxion can seamlessly integrate with pipelines for inverse problems, inference, optimization, and coupling to ML models.
|
|
34
|
+
|
|
35
|
+
Jaxion is the simpler companion project to the differentiable astrophysics code [Adirondax](https://github.com/AdirondaxProject/adirondax).
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
## Getting started
|
|
39
|
+
|
|
40
|
+
Install with:
|
|
41
|
+
|
|
42
|
+
```console
|
|
43
|
+
pip install jaxion
|
|
44
|
+
```
|
|
45
|
+
|
|
46
|
+
or, for GPU support, use:
|
|
47
|
+
|
|
48
|
+
```console
|
|
49
|
+
pip install jaxion[cuda12]
|
|
50
|
+
```
|
|
51
|
+
|
|
52
|
+
See the docs for more info on how to [build from source](https://jaxion.readthedocs.io/en/latest/pages/installation.html).
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
## Examples
|
|
56
|
+
|
|
57
|
+
Check out the `examples/` directory for demonstrations of using Jaxion.
|
|
58
|
+
|
|
59
|
+
<p align="center">
|
|
60
|
+
<a href="https://github.com/JaxionProject/jaxion/tree/main/examples/cosmological_box">
|
|
61
|
+
<img src="examples/cosmological_box/movie.gif" alt="cosmological_box" width="128"/>
|
|
62
|
+
</a>
|
|
63
|
+
<a href="https://github.com/JaxionProject/jaxion/tree/main/examples/dynamical_friction">
|
|
64
|
+
<img src="examples/dynamical_friction/movie.gif" alt="dynamical_friction" width="128"/>
|
|
65
|
+
</a>
|
|
66
|
+
<a href="https://github.com/JaxionProject/jaxion/tree/main/examples/heating_gas">
|
|
67
|
+
<img src="examples/heating_gas/movie.gif" alt="heating_gas" width="128"/>
|
|
68
|
+
</a>
|
|
69
|
+
<br>
|
|
70
|
+
<a href="https://github.com/JaxionProject/jaxion/tree/main/examples/heating_stars">
|
|
71
|
+
<img src="examples/heating_stars/movie.gif" alt="heating_stars" width="128"/>
|
|
72
|
+
</a>
|
|
73
|
+
<a href="https://github.com/JaxionProject/jaxion/tree/main/examples/kinetic_condensation">
|
|
74
|
+
<img src="examples/kinetic_condensation/movie.gif" alt="kinetic_condensation" width="128"/>
|
|
75
|
+
</a>
|
|
76
|
+
<a href="https://github.com/JaxionProject/jaxion/tree/main/examples/logo">
|
|
77
|
+
<img src="examples/logo/movie.gif" alt="logo" width="128"/>
|
|
78
|
+
</a>
|
|
79
|
+
<br>
|
|
80
|
+
<a href="https://github.com/JaxionProject/jaxion/tree/main/examples/soliton_binary_merger">
|
|
81
|
+
<img src="examples/soliton_binary_merger/movie.gif" alt="soliton_binary_merger" width="128"/>
|
|
82
|
+
</a>
|
|
83
|
+
<a href="https://github.com/JaxionProject/jaxion/tree/main/examples/soliton_merger">
|
|
84
|
+
<img src="examples/soliton_merger/movie.gif" alt="soliton_merger" width="128"/>
|
|
85
|
+
</a>
|
|
86
|
+
<a href="https://github.com/JaxionProject/jaxion/tree/main/examples/tidal_stripping">
|
|
87
|
+
<img src="examples/tidal_stripping/movie.gif" alt="tidal_stripping" width="128"/>
|
|
88
|
+
</a>
|
|
89
|
+
</p>
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
## Links
|
|
93
|
+
|
|
94
|
+
* [Code repository](https://github.com/JaxionProject/jaxion) on GitHub (this page).
|
|
95
|
+
* [Documentation](https://jaxion.readthedocs.io) for up-to-date information about installing and using jaxion.
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
## Testing
|
|
99
|
+
|
|
100
|
+
Jaxion is tested with `pytest`. Tests are included in the `tests/` folder.
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
## Contributing
|
|
104
|
+
|
|
105
|
+
Jaxion welcomes community contributions of all kinds. Open an issue, or fork the code and submit a pull request. Please check out the [Contributing Guidelines](CONTRIBUTING.md).
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
## Cite this repository
|
|
109
|
+
|
|
110
|
+
TODO XXX
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
import jax.numpy as jnp
|
|
2
|
+
import jaxdecomp as jd
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def radial_power_spectrum(data_cube, kx, ky, kz, box_size):
    """
    Radially average the power spectral density of a cubic 3D field.

    Parameters
    ----------
    data_cube : jnp.ndarray
        3D data cube; the length of its first axis is used as the
        resolution of every axis, so it must be cubic
    kx, ky, kz : jnp.ndarray
        wavenumber grids in each dimension
    box_size : float
        physical size of box

    Returns
    -------
    Pf : jnp.ndarray
        radial power spectrum, one value per shell
    k : jnp.ndarray
        shell wavenumbers, 2 * pi * i / box_size for i = 0 .. nx // 2
    total_power : float
        total power summed over all Fourier modes

    Notes
    -----
    NOTE(review): the shells are unit-width bins centred on the integers
    0 .. nx // 2, so `kx`, `ky`, `kz` are presumably expected in units of
    the fundamental mode -- confirm against the caller.
    """
    n_dims = data_cube.ndim
    n_cells = data_cube.shape[0]
    cell_size = box_size / n_cells
    n_shells = n_cells // 2 + 1
    shell_range = (-0.5, n_shells - 0.5)

    # Power carried by each Fourier mode, and its sum over all modes
    mode_amplitude_sq = jnp.abs(jd.fft.pfft3d(data_cube)) ** 2
    total_power = (
        0.5 * jnp.sum(mode_amplitude_sq) / n_cells**n_dims * cell_size**n_dims
    )
    mode_power = 0.5 * mode_amplitude_sq / n_cells**n_dims * cell_size**n_dims

    # Bin the modes into spherical shells of |k| and average within each shell
    k_magnitude = jnp.sqrt(kx**2 + ky**2 + kz**2)
    shell_power, _ = jnp.histogram(
        k_magnitude, range=shell_range, bins=n_shells, weights=mode_power
    )
    shell_count, _ = jnp.histogram(k_magnitude, range=shell_range, bins=n_shells)
    # Empty shells are divided by 1 instead of 0
    shell_mean = shell_power / (shell_count + (shell_count == 0))

    k = 2.0 * jnp.pi * jnp.arange(n_shells) / box_size
    dk = 2.0 * jnp.pi / box_size

    # Normalise by the k-space cell volume, then apply the 3D shell factor
    Pf = shell_mean / dk**n_dims
    Pf = Pf * (4.0 * jnp.pi * k**2)

    return Pf, k, total_power
|
|
@@ -1,13 +1,13 @@
|
|
|
1
1
|
# Copyright (c) 2025 The Jaxion Team.
|
|
2
2
|
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
3
|
+
"""
|
|
4
|
+
Physical constants in units of:
|
|
5
|
+
[L] = kpc,
|
|
6
|
+
[V] = km/s,
|
|
7
|
+
[M] = Msun
|
|
8
|
+
|
|
9
|
+
note: other units are derived from these base units, e.g., [T] = [L]/[V] = kpc / (km/s) ~= 0.978 Gyr
|
|
10
|
+
"""
|
|
11
11
|
|
|
12
12
|
constants = {
|
|
13
13
|
"gravitational_constant": 4.30241002e-6, # G / (kpc * (km/s)^2 / Msun)
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
import jax.numpy as jnp
|
|
2
|
+
|
|
3
|
+
# Pure functions for cosmology simulation
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def get_physical_time_interval(z_start, z_end, omega_matter, omega_lambda, little_h):
    """Return the physical time elapsed between redshifts z_start and z_end.

    Integrates dt = da / (da/dt) with the trapezoid rule over 10000 scale
    factors evenly spaced between a(z_start) and a(z_end), where
    da/dt = H0 * sqrt(omega_matter / a + omega_lambda * a^2)
    and H0 = 0.1 * little_h is the Hubble constant in km/s/kpc.
    """
    hubble_0 = 0.1 * little_h
    scale_factors = jnp.linspace(1.0 / (1.0 + z_start), 1.0 / (1.0 + z_end), 10000)
    expansion_rate = hubble_0 * jnp.sqrt(
        (omega_matter / scale_factors) + (omega_lambda * scale_factors**2)
    )
    elapsed = jnp.trapezoid(1.0 / expansion_rate, scale_factors)
    return elapsed
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def get_supercomoving_time_interval(
    z_start, z_end, omega_matter, omega_lambda, little_h
):
    """Return the supercomoving time (dt_hat = a^-2 dt) between two redshifts.

    With da/dt = H0 * sqrt(omega_matter / a + omega_lambda * a^2), the
    integrand is dt_hat/da = a^-2 / (da/dt). It is integrated with the
    trapezoid rule over 10000 scale factors evenly spaced between a(z_start)
    and a(z_end). H0 = 0.1 * little_h is the Hubble constant in km/s/kpc.
    """
    hubble_0 = 0.1 * little_h
    scale_factors = jnp.linspace(1.0 / (1.0 + z_start), 1.0 / (1.0 + z_end), 10000)
    expansion_rate = hubble_0 * jnp.sqrt(
        (omega_matter / scale_factors) + (omega_lambda * scale_factors**2)
    )
    integrand = scale_factors**-2 / expansion_rate
    return jnp.trapezoid(integrand, scale_factors)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def get_scale_factor(z_start, dt_hat, omega_matter, omega_lambda, little_h):
    """Find the scale factor reached a supercomoving time dt_hat after z_start.

    Root finding by bisection: the first trial value is the scale factor at
    z_start, and the bracket starts as [0, 2]. Each trial evaluates the
    supercomoving time from z_start to the trial scale factor. The search
    stops once that time is within 1e-6 of dt_hat, or after 100 trials,
    and the current trial value is returned.
    """
    bracket_low, bracket_high = 0.0, 2.0
    a = 1.0 / (1.0 + z_start)
    for _ in range(100):
        elapsed = get_supercomoving_time_interval(
            z_start, 1.0 / a - 1.0, omega_matter, omega_lambda, little_h
        )
        residual = elapsed - dt_hat
        if jnp.abs(residual) < 1e-6:
            return a
        if residual > 0:
            # Overshot the target time: the root lies below the trial value
            bracket_high = a
        else:
            # Undershot the target time: the root lies above the trial value
            bracket_low = a
        a = (bracket_low + bracket_high) / 2
    return a
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def get_next_scale_factor(z_start, dt_hat, omega_matter, omega_lambda, little_h):
    """Advance the scale factor at z_start by one supercomoving step dt_hat.

    Uses a single RK2 (midpoint) step on
    da/dt_hat = a^2 * H0 * sqrt(omega_matter / a + omega_lambda * a^2),
    with H0 = 0.1 * little_h.
    """
    hubble_0 = 0.1 * little_h
    a_initial = 1.0 / (1.0 + z_start)

    def rate(a):
        # da/dt_hat = a^2 * da/dt
        return a**2 * (hubble_0 * jnp.sqrt((omega_matter / a) + (omega_lambda * a**2)))

    a_midpoint = a_initial + 0.5 * dt_hat * rate(a_initial)
    return a_initial + dt_hat * rate(a_midpoint)
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
import jax.numpy as jnp
|
|
2
|
+
import jaxdecomp as jd
|
|
3
|
+
|
|
4
|
+
# Pure functions for gravity calculations
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def calculate_gravitational_potential(rho, k_sq, G, rho_bar):
    """Compute the gravitational potential of density rho spectrally.

    Transforms 4 * pi * G * (rho - rho_bar) to Fourier space, divides by
    -k_sq, and returns the real part of the inverse transform.
    """
    source_hat = jd.fft.pfft3d(4.0 * jnp.pi * G * (rho - rho_bar))
    # Where k_sq is zero the divisor becomes 1, avoiding a division by zero
    safe_k_sq = k_sq + (k_sq == 0)
    potential_hat = -source_hat / safe_k_sq
    return jnp.real(jd.fft.pifft3d(potential_hat))
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import jax.numpy as jnp
|
|
2
|
+
import jaxdecomp as jd
|
|
2
3
|
|
|
3
4
|
# Pure functions for hydro simulation
|
|
4
5
|
|
|
@@ -89,11 +90,11 @@ def get_flux(rho_L, vx_L, vy_L, vz_L, rho_R, vx_R, vy_R, vz_R, cs):
|
|
|
89
90
|
|
|
90
91
|
|
|
91
92
|
def hydro_accelerate(vx, vy, vz, V, kx, ky, kz, dt):
|
|
92
|
-
V_hat =
|
|
93
|
+
V_hat = jd.fft.pfft3d(V)
|
|
93
94
|
|
|
94
|
-
ax = -jnp.real(
|
|
95
|
-
ay = -jnp.real(
|
|
96
|
-
az = -jnp.real(
|
|
95
|
+
ax = -jnp.real(jd.fft.pifft3d(1.0j * kx * V_hat))
|
|
96
|
+
ay = -jnp.real(jd.fft.pifft3d(1.0j * ky * V_hat))
|
|
97
|
+
az = -jnp.real(jd.fft.pifft3d(1.0j * kz * V_hat))
|
|
97
98
|
|
|
98
99
|
vx += ax * dt
|
|
99
100
|
vy += ay * dt
|
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
{
|
|
2
|
+
"physics": {
|
|
3
|
+
"quantum": {
|
|
4
|
+
"default": true,
|
|
5
|
+
"description": "switch on for fuzzy dark matter."
|
|
6
|
+
},
|
|
7
|
+
"gravity": {
|
|
8
|
+
"default": true,
|
|
9
|
+
"description": "switch on for self-gravity."
|
|
10
|
+
},
|
|
11
|
+
"hydro": {
|
|
12
|
+
"default": false,
|
|
13
|
+
"description": "switch on for hydrodynamics (isothermal)."
|
|
14
|
+
},
|
|
15
|
+
"particles": {
|
|
16
|
+
"default": false,
|
|
17
|
+
"description": "switch on for particle-mesh particles."
|
|
18
|
+
},
|
|
19
|
+
"cosmology": {
|
|
20
|
+
"default": false,
|
|
21
|
+
"description": "switch on for cosmological factors/comoving units."
|
|
22
|
+
},
|
|
23
|
+
"external_potential": {
|
|
24
|
+
"default": false,
|
|
25
|
+
"description": "switch on for external gravitational potential."
|
|
26
|
+
}
|
|
27
|
+
},
|
|
28
|
+
"domain": {
|
|
29
|
+
"box_size": {
|
|
30
|
+
"default": 10.0,
|
|
31
|
+
"description": "periodic domain box size [kpc]."
|
|
32
|
+
},
|
|
33
|
+
"resolution_base": {
|
|
34
|
+
"default": 32,
|
|
35
|
+
"description": "base resolution per linear dimension."
|
|
36
|
+
},
|
|
37
|
+
"resolution_multiplier": {
|
|
38
|
+
"default": 1,
|
|
39
|
+
"description": "resolution multiplier."
|
|
40
|
+
}
|
|
41
|
+
},
|
|
42
|
+
"time": {
|
|
43
|
+
"start": {
|
|
44
|
+
"default": 0.0,
|
|
45
|
+
"description": "simulation start time [kpc/(km/s)] or [redshift] (cosmology=true)."
|
|
46
|
+
},
|
|
47
|
+
"end": {
|
|
48
|
+
"default": 1.0,
|
|
49
|
+
"description": "simulation end time [kpc/(km/s)] or [redshift] (cosmology=true)."
|
|
50
|
+
},
|
|
51
|
+
"adaptive": {
|
|
52
|
+
"default": false,
|
|
53
|
+
"description": "switch on for adaptive time stepping."
|
|
54
|
+
}
|
|
55
|
+
},
|
|
56
|
+
"output": {
|
|
57
|
+
"path": {
|
|
58
|
+
"default": "./checkpoints",
|
|
59
|
+
"description": "path to output directory."
|
|
60
|
+
},
|
|
61
|
+
"num_checkpoints": {
|
|
62
|
+
"default": 100,
|
|
63
|
+
"description": "number of checkpoints to save."
|
|
64
|
+
},
|
|
65
|
+
"save": {
|
|
66
|
+
"default": true,
|
|
67
|
+
"description": "switch on to save checkpoints."
|
|
68
|
+
},
|
|
69
|
+
"plot_dynamic_range": {
|
|
70
|
+
"default": 100.0,
|
|
71
|
+
"description": "dynamic range for plotting."
|
|
72
|
+
}
|
|
73
|
+
},
|
|
74
|
+
"quantum": {
|
|
75
|
+
"m_22": {
|
|
76
|
+
"default": 1.0,
|
|
77
|
+
"description": "axion mass [10^{-22} eV]."
|
|
78
|
+
}
|
|
79
|
+
},
|
|
80
|
+
"hydro": {
|
|
81
|
+
"sound_speed": {
|
|
82
|
+
"default": 1.0,
|
|
83
|
+
"description": "isothermal sound speed [km/s]."
|
|
84
|
+
}
|
|
85
|
+
},
|
|
86
|
+
"particles": {
|
|
87
|
+
"num_particles": {
|
|
88
|
+
"default": 0,
|
|
89
|
+
"description": "number of particles."
|
|
90
|
+
},
|
|
91
|
+
"particle_mass": {
|
|
92
|
+
"default": 1.0,
|
|
93
|
+
"description": "particle mass [M_sun]."
|
|
94
|
+
}
|
|
95
|
+
},
|
|
96
|
+
"cosmology": {
|
|
97
|
+
"omega_matter": {
|
|
98
|
+
"default": 0.3,
|
|
99
|
+
"description": "matter density parameter."
|
|
100
|
+
},
|
|
101
|
+
"omega_lambda": {
|
|
102
|
+
"default": 0.7,
|
|
103
|
+
"description": "dark energy density parameter."
|
|
104
|
+
},
|
|
105
|
+
"little_h": {
|
|
106
|
+
"default": 0.7,
|
|
107
|
+
"description": "Hubble parameter little h (H_0 = 100 h km/s/Mpc)."
|
|
108
|
+
}
|
|
109
|
+
},
|
|
110
|
+
"version": {
|
|
111
|
+
"default": "unknown",
|
|
112
|
+
"description": "jaxion version used (auto detected)."
|
|
113
|
+
}
|
|
114
|
+
}
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import jax
|
|
2
2
|
import jax.numpy as jnp
|
|
3
|
+
import jaxdecomp as jd
|
|
3
4
|
|
|
4
5
|
# Pure functions for particle-mesh calculations (stars, BHs, ...)
|
|
5
6
|
|
|
@@ -64,10 +65,10 @@ def get_acceleration(pos, V, kx, ky, kz, dx):
|
|
|
64
65
|
i, ip1, w_i, w_ip1 = get_cic_indices_and_weights(pos, dx, resolution)
|
|
65
66
|
|
|
66
67
|
# find accelerations on the grid
|
|
67
|
-
V_hat =
|
|
68
|
-
ax = -jnp.real(
|
|
69
|
-
ay = -jnp.real(
|
|
70
|
-
az = -jnp.real(
|
|
68
|
+
V_hat = jd.fft.pfft3d(V)
|
|
69
|
+
ax = -jnp.real(jd.fft.pifft3d(1.0j * kx * V_hat))
|
|
70
|
+
ay = -jnp.real(jd.fft.pifft3d(1.0j * ky * V_hat))
|
|
71
|
+
az = -jnp.real(jd.fft.pifft3d(1.0j * kz * V_hat))
|
|
71
72
|
a_grid = jnp.stack((ax, ay, az), axis=-1)
|
|
72
73
|
|
|
73
74
|
# interpolate the accelerations to the particle positions
|
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
import jax.numpy as jnp
|
|
2
|
+
import jaxdecomp as jd
|
|
3
|
+
|
|
4
|
+
# Pure functions for quantum simulation
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def quantum_kick(psi, V, m_per_hbar, dt):
    """Apply the potential step: psi -> exp(-i * m_per_hbar * dt * V) * psi."""
    phase_factor = jnp.exp(-1.0j * m_per_hbar * dt * V)
    return phase_factor * psi
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def quantum_drift(psi, k_sq, m_per_hbar, dt):
    """Apply the kinetic step in Fourier space.

    Each Fourier mode of psi is multiplied by
    exp(-i * dt * k_sq / (2 * m_per_hbar)) before transforming back.
    """
    propagator = jnp.exp(dt * (-1.0j * k_sq / m_per_hbar / 2.0))
    return jd.fft.pifft3d(propagator * jd.fft.pfft3d(psi))
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def get_gradient(psi, kx, ky, kz):
    """
    Computes the spectral gradient of the wavefunction psi
    psi: jnp.ndarray (3D)
    kx, ky, kz: wavenumber grids multiplying the transform of psi
    Returns: grad_psi_x, grad_psi_y, grad_psi_z
    """
    psi_hat = jd.fft.pfft3d(psi)
    # Differentiation along each axis is multiplication by i * k in Fourier space
    grad_psi_x, grad_psi_y, grad_psi_z = (
        jd.fft.pifft3d(1.0j * k_axis * psi_hat) for k_axis in (kx, ky, kz)
    )
    return grad_psi_x, grad_psi_y, grad_psi_z
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _wrap_phase_difference(delta):
    """Wrap differences of two phase angles from [-2*pi, 2*pi] into (-pi, pi]."""
    delta = jnp.where(delta > jnp.pi, delta - 2 * jnp.pi, delta)
    return jnp.where(delta <= -jnp.pi, delta + 2 * jnp.pi, delta)


def quantum_velocity(psi, box_size, m_per_hbar):
    """
    Compute the velocity from the wave-function
    v = nabla S / m
    psi = sqrt(rho) exp(i S / hbar)

    The phase S / hbar is taken as the complex angle of psi. Along each of
    the three axes the gradient is the average of the forward and backward
    periodic finite differences, each wrapped so that 2*pi jumps of the
    angle do not appear as spurious velocities.

    psi: jnp.ndarray (3D); the length of its first axis sets the grid
        spacing dx = box_size / N used for all three axes
    box_size: physical size of the periodic box
    m_per_hbar: mass over hbar, dividing the phase gradient
    Returns: vx, vy, vz
    """
    dx = box_size / psi.shape[0]
    phase = jnp.angle(psi)

    components = []
    for axis in range(3):
        # Periodic neighbours are obtained by rolling the phase along the axis
        forward = _wrap_phase_difference(jnp.roll(phase, -1, axis=axis) - phase)
        backward = _wrap_phase_difference(phase - jnp.roll(phase, 1, axis=axis))
        forward = forward / dx / m_per_hbar
        backward = backward / dx / m_per_hbar
        # Average forward and backward
        components.append(0.5 * (forward + backward))

    vx, vy, vz = components
    return vx, vy, vz
|