jaxion 0.0.4__tar.gz → 0.0.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (26)
  1. {jaxion-0.0.4 → jaxion-0.0.6}/PKG-INFO +41 -13
  2. {jaxion-0.0.4 → jaxion-0.0.6}/README.md +40 -12
  3. jaxion-0.0.4/jaxion/params_default.json → jaxion-0.0.6/jaxion/defaults.json +4 -0
  4. {jaxion-0.0.4 → jaxion-0.0.6}/jaxion/gravity.py +2 -2
  5. {jaxion-0.0.4 → jaxion-0.0.6}/jaxion/simulation.py +122 -46
  6. {jaxion-0.0.4 → jaxion-0.0.6}/jaxion/utils.py +18 -1
  7. {jaxion-0.0.4 → jaxion-0.0.6}/jaxion/visualization.py +10 -6
  8. {jaxion-0.0.4 → jaxion-0.0.6}/jaxion.egg-info/PKG-INFO +41 -13
  9. {jaxion-0.0.4 → jaxion-0.0.6}/jaxion.egg-info/SOURCES.txt +3 -1
  10. {jaxion-0.0.4 → jaxion-0.0.6}/pyproject.toml +1 -1
  11. jaxion-0.0.6/tests/test_analysis.py +24 -0
  12. jaxion-0.0.6/tests/test_cosmology.py +29 -0
  13. {jaxion-0.0.4 → jaxion-0.0.6}/tests/test_examples.py +13 -2
  14. {jaxion-0.0.4 → jaxion-0.0.6}/LICENSE +0 -0
  15. {jaxion-0.0.4 → jaxion-0.0.6}/jaxion/__init__.py +0 -0
  16. {jaxion-0.0.4 → jaxion-0.0.6}/jaxion/analysis.py +0 -0
  17. {jaxion-0.0.4 → jaxion-0.0.6}/jaxion/constants.py +0 -0
  18. {jaxion-0.0.4 → jaxion-0.0.6}/jaxion/cosmology.py +0 -0
  19. {jaxion-0.0.4 → jaxion-0.0.6}/jaxion/hydro.py +0 -0
  20. {jaxion-0.0.4 → jaxion-0.0.6}/jaxion/particles.py +0 -0
  21. {jaxion-0.0.4 → jaxion-0.0.6}/jaxion/quantum.py +0 -0
  22. {jaxion-0.0.4 → jaxion-0.0.6}/jaxion.egg-info/dependency_links.txt +0 -0
  23. {jaxion-0.0.4 → jaxion-0.0.6}/jaxion.egg-info/requires.txt +0 -0
  24. {jaxion-0.0.4 → jaxion-0.0.6}/jaxion.egg-info/top_level.txt +0 -0
  25. {jaxion-0.0.4 → jaxion-0.0.6}/requirements.txt +0 -0
  26. {jaxion-0.0.4 → jaxion-0.0.6}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: jaxion
3
- Version: 0.0.4
3
+ Version: 0.0.6
4
4
  Summary: A differentiable simulation library for fuzzy dark matter in JAX
5
5
  Author-email: Philip Mocz <philip.mocz@gmail.com>
6
6
  License-Expression: Apache-2.0
@@ -33,8 +33,11 @@ Dynamic: license-file
33
33
  [![PyPI Version Status][pypi-badge]][pypi-link]
34
34
  [![Test Status][workflow-test-badge]][workflow-test-link]
35
35
  [![Coverage][coverage-badge]][coverage-link]
36
+ [![Ruff][ruff-badge]][ruff-link]
37
+ [![asv][asv-badge]][asv-link]
36
38
  [![Readthedocs Status][docs-badge]][docs-link]
37
39
  [![License][license-badge]][license-link]
40
+ [![Software DOI][software-doi-badge]][software-doi-link]
38
41
 
39
42
  [status-link]: https://www.repostatus.org/#active
40
43
  [status-badge]: https://www.repostatus.org/badges/latest/active.svg
@@ -44,10 +47,17 @@ Dynamic: license-file
44
47
  [workflow-test-badge]: https://github.com/JaxionProject/jaxion/actions/workflows/test-package.yml/badge.svg?event=push
45
48
  [coverage-link]: https://app.codecov.io/gh/JaxionProject/jaxion
46
49
  [coverage-badge]: https://codecov.io/github/jaxionproject/jaxion/graph/jaxion-server/badge.svg
50
+ [ruff-link]: https://github.com/astral-sh/ruff
51
+ [ruff-badge]: https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json
52
+ [asv-link]: https://jaxionproject.github.io/jaxion-benchmarks/
53
+ [asv-badge]: https://img.shields.io/badge/benchmarked%20by-asv-blue.svg?style=flat
47
54
  [docs-link]: https://jaxion.readthedocs.io
48
55
  [docs-badge]: https://readthedocs.org/projects/jaxion/badge
49
56
  [license-link]: https://opensource.org/licenses/Apache-2.0
50
57
  [license-badge]: https://img.shields.io/badge/License-Apache_2.0-blue.svg
58
+ [software-doi-link]: https://doi.org/10.5281/zenodo.17438467
59
+ [software-doi-badge]: https://zenodo.org/badge/1072645376.svg
60
+
51
61
 
52
62
  A simple JAX-powered simulation library for numerical experiments of fuzzy dark matter, stars, gas + more!
53
63
 
@@ -58,7 +68,7 @@ Jaxion is built for multi-GPU scalability and is fully differentiable. It is a h
58
68
  Jaxion is the simpler companion project to differentiable astrophysics code [Adirondax](https://github.com/AdirondaxProject/adirondax)
59
69
 
60
70
 
61
- ## Getting started
71
+ ## Install Jaxion
62
72
 
63
73
  Install with:
64
74
 
@@ -77,7 +87,7 @@ See the docs for more info on how to [build from source](https://jaxion.readthed
77
87
 
78
88
  ## Examples
79
89
 
80
- Check out the `examples/` directory for demonstrations of using Jaxion.
90
+ Check out the [`examples/`](https://github.com/JaxionProject/jaxion/tree/main/examples/) directory for demonstrations of using Jaxion.
81
91
 
82
92
  <p align="center">
83
93
  <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/cosmological_box">
@@ -96,8 +106,8 @@ Check out the `examples/` directory for demonstrations of using Jaxion.
96
106
  <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/kinetic_condensation">
97
107
  <img src="examples/kinetic_condensation/movie.gif" alt="kinetic_condensation" width="128"/>
98
108
  </a>
99
- <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/logo">
100
- <img src="examples/logo/movie.gif" alt="logo" width="128"/>
109
+ <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/logo_inverse_problem">
110
+ <img src="examples/logo_inverse_problem/movie.gif" alt="logo_inverse_problem" width="128"/>
101
111
  </a>
102
112
  <br>
103
113
  <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/soliton_binary_merger">
@@ -112,22 +122,40 @@ Check out the `examples/` directory for demonstrations of using Jaxion.
112
122
  </p>
113
123
 
114
124
 
115
- ## Links
125
+ ## High-Performance
116
126
 
117
- * [Code repository](https://github.com/JaxionProject/jaxion) on GitHub (this page).
118
- * [Documentation](https://jaxion.readthedocs.io) for up-to-date information about installing and using jaxion.
127
+ Jaxion is scalable on multiple GPUs!
119
128
 
129
+ <p align="center">
130
+ <a href="https://jaxion.readthedocs.io">
131
+ <img src="examples/soliton_binary_merger/timing.png" alt="timing" width="400"/>
132
+ </a>
133
+ </p>
120
134
 
121
- ## Testing
122
135
 
123
- Jaxion is tested with `pytest`. Tests are included in the `tests/` folder.
136
+ ## Contributing
124
137
 
138
+ Jaxion welcomes community contributions of all kinds. Open an issue or fork the code and submit a pull request. Please check out the [Contributing Guidelines](CONTRIBUTING.md)
125
139
 
126
- ## Contributing
127
140
 
128
- Jaxion welcomes community contributions of all kinds. Open an issue or fork the code and submit a Pull request. Please check out the [Contributing Guidelines](CONTRIBUTING.md)
141
+ ## Links
142
+
143
+ * [Code repository](https://github.com/JaxionProject/jaxion) on GitHub (this page).
144
+ * [Documentation](https://jaxion.readthedocs.io) for up-to-date information about installing and using Jaxion.
129
145
 
130
146
 
131
147
  ## Cite this repository
132
148
 
133
- TODO XXX
149
+ If you use this software, please cite it as below.
150
+
151
+ ```bibtex
152
+ @software{Mocz_Jaxion_2025,
153
+ author = {Mocz, Philip},
154
+ doi = {10.5281/zenodo.17438467},
155
+ month = oct,
156
+ title = {{Jaxion}},
157
+ url = {https://github.com/JaxionProject/jaxion},
158
+ version = {0.0.4},
159
+ year = {2025}
160
+ }
161
+ ```
@@ -10,8 +10,11 @@
10
10
  [![PyPI Version Status][pypi-badge]][pypi-link]
11
11
  [![Test Status][workflow-test-badge]][workflow-test-link]
12
12
  [![Coverage][coverage-badge]][coverage-link]
13
+ [![Ruff][ruff-badge]][ruff-link]
14
+ [![asv][asv-badge]][asv-link]
13
15
  [![Readthedocs Status][docs-badge]][docs-link]
14
16
  [![License][license-badge]][license-link]
17
+ [![Software DOI][software-doi-badge]][software-doi-link]
15
18
 
16
19
  [status-link]: https://www.repostatus.org/#active
17
20
  [status-badge]: https://www.repostatus.org/badges/latest/active.svg
@@ -21,10 +24,17 @@
21
24
  [workflow-test-badge]: https://github.com/JaxionProject/jaxion/actions/workflows/test-package.yml/badge.svg?event=push
22
25
  [coverage-link]: https://app.codecov.io/gh/JaxionProject/jaxion
23
26
  [coverage-badge]: https://codecov.io/github/jaxionproject/jaxion/graph/jaxion-server/badge.svg
27
+ [ruff-link]: https://github.com/astral-sh/ruff
28
+ [ruff-badge]: https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json
29
+ [asv-link]: https://jaxionproject.github.io/jaxion-benchmarks/
30
+ [asv-badge]: https://img.shields.io/badge/benchmarked%20by-asv-blue.svg?style=flat
24
31
  [docs-link]: https://jaxion.readthedocs.io
25
32
  [docs-badge]: https://readthedocs.org/projects/jaxion/badge
26
33
  [license-link]: https://opensource.org/licenses/Apache-2.0
27
34
  [license-badge]: https://img.shields.io/badge/License-Apache_2.0-blue.svg
35
+ [software-doi-link]: https://doi.org/10.5281/zenodo.17438467
36
+ [software-doi-badge]: https://zenodo.org/badge/1072645376.svg
37
+
28
38
 
29
39
  A simple JAX-powered simulation library for numerical experiments of fuzzy dark matter, stars, gas + more!
30
40
 
@@ -35,7 +45,7 @@ Jaxion is built for multi-GPU scalability and is fully differentiable. It is a h
35
45
  Jaxion is the simpler companion project to differentiable astrophysics code [Adirondax](https://github.com/AdirondaxProject/adirondax)
36
46
 
37
47
 
38
- ## Getting started
48
+ ## Install Jaxion
39
49
 
40
50
  Install with:
41
51
 
@@ -54,7 +64,7 @@ See the docs for more info on how to [build from source](https://jaxion.readthed
54
64
 
55
65
  ## Examples
56
66
 
57
- Check out the `examples/` directory for demonstrations of using Jaxion.
67
+ Check out the [`examples/`](https://github.com/JaxionProject/jaxion/tree/main/examples/) directory for demonstrations of using Jaxion.
58
68
 
59
69
  <p align="center">
60
70
  <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/cosmological_box">
@@ -73,8 +83,8 @@ Check out the `examples/` directory for demonstrations of using Jaxion.
73
83
  <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/kinetic_condensation">
74
84
  <img src="examples/kinetic_condensation/movie.gif" alt="kinetic_condensation" width="128"/>
75
85
  </a>
76
- <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/logo">
77
- <img src="examples/logo/movie.gif" alt="logo" width="128"/>
86
+ <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/logo_inverse_problem">
87
+ <img src="examples/logo_inverse_problem/movie.gif" alt="logo_inverse_problem" width="128"/>
78
88
  </a>
79
89
  <br>
80
90
  <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/soliton_binary_merger">
@@ -89,22 +99,40 @@ Check out the `examples/` directory for demonstrations of using Jaxion.
89
99
  </p>
90
100
 
91
101
 
92
- ## Links
102
+ ## High-Performance
93
103
 
94
- * [Code repository](https://github.com/JaxionProject/jaxion) on GitHub (this page).
95
- * [Documentation](https://jaxion.readthedocs.io) for up-to-date information about installing and using jaxion.
104
+ Jaxion is scalable on multiple GPUs!
96
105
 
106
+ <p align="center">
107
+ <a href="https://jaxion.readthedocs.io">
108
+ <img src="examples/soliton_binary_merger/timing.png" alt="timing" width="400"/>
109
+ </a>
110
+ </p>
97
111
 
98
- ## Testing
99
112
 
100
- Jaxion is tested with `pytest`. Tests are included in the `tests/` folder.
113
+ ## Contributing
101
114
 
115
+ Jaxion welcomes community contributions of all kinds. Open an issue or fork the code and submit a pull request. Please check out the [Contributing Guidelines](CONTRIBUTING.md)
102
116
 
103
- ## Contributing
104
117
 
105
- Jaxion welcomes community contributions of all kinds. Open an issue or fork the code and submit a Pull request. Please check out the [Contributing Guidelines](CONTRIBUTING.md)
118
+ ## Links
119
+
120
+ * [Code repository](https://github.com/JaxionProject/jaxion) on GitHub (this page).
121
+ * [Documentation](https://jaxion.readthedocs.io) for up-to-date information about installing and using Jaxion.
106
122
 
107
123
 
108
124
  ## Cite this repository
109
125
 
110
- TODO XXX
126
+ If you use this software, please cite it as below.
127
+
128
+ ```bibtex
129
+ @software{Mocz_Jaxion_2025,
130
+ author = {Mocz, Philip},
131
+ doi = {10.5281/zenodo.17438467},
132
+ month = oct,
133
+ title = {{Jaxion}},
134
+ url = {https://github.com/JaxionProject/jaxion},
135
+ version = {0.0.4},
136
+ year = {2025}
137
+ }
138
+ ```
@@ -75,6 +75,10 @@
75
75
  "m_22": {
76
76
  "default": 1.0,
77
77
  "description": "axion mass [10^{-22} eV]."
78
+ },
79
+ "f_15": {
80
+ "default": 0.0,
81
+ "description": "self-interaction strength [10^{15} GeV]."
78
82
  }
79
83
  },
80
84
  "hydro": {
@@ -5,6 +5,6 @@ import jaxdecomp as jd
5
5
 
6
6
 
7
7
  def calculate_gravitational_potential(rho, k_sq, G, rho_bar):
8
- Vhat = -jd.fft.pfft3d(4.0 * jnp.pi * G * (rho - rho_bar)) / (k_sq + (k_sq == 0))
9
- V = jnp.real(jd.fft.pifft3d(Vhat))
8
+ V_hat = -jd.fft.pfft3d(4.0 * jnp.pi * G * (rho - rho_bar)) / (k_sq + (k_sq == 0))
9
+ V = jnp.real(jd.fft.pifft3d(V_hat))
10
10
  return V
@@ -14,6 +14,7 @@ from .cosmology import get_supercomoving_time_interval, get_next_scale_factor
14
14
  from .utils import (
15
15
  set_up_parameters,
16
16
  print_parameters,
17
+ print_distributed_info,
17
18
  xmeshgrid,
18
19
  xmeshgrid_transpose,
19
20
  xzeros,
@@ -61,9 +62,10 @@ class Simulation:
61
62
  self.params["physics"]["hydro"]
62
63
  or self.params["physics"]["particles"]
63
64
  or self.params["physics"]["external_potential"]
65
+ or self.params["quantum"]["f_15"] != 0.0
64
66
  ):
65
67
  raise NotImplementedError(
66
- "Cosmological hydro/particles/external_potential physics is not yet implemented."
68
+ "Cosmological hydro/particles/external_potential/SI is not yet implemented."
67
69
  )
68
70
 
69
71
  if self.params["physics"]["hydro"] or self.params["physics"]["particles"]:
@@ -72,10 +74,12 @@ class Simulation:
72
74
  "hydro/particles sharding is not yet implemented."
73
75
  )
74
76
 
75
- if self.params["output"]["save"]:
76
- if jax.process_index() == 0:
77
- print("Simulation parameters:")
78
- print_parameters(self.params)
77
+ # print info
78
+ if jax.process_index() == 0:
79
+ print("Simulation parameters:")
80
+ print_parameters(self.params)
81
+ if sharding is not None:
82
+ print_distributed_info()
79
83
 
80
84
  # jitted functions
81
85
  self.xmeshgrid_jit = jax.jit(
@@ -91,6 +95,12 @@ class Simulation:
91
95
  xones, static_argnums=0, in_shardings=None, out_shardings=sharding
92
96
  )
93
97
 
98
+ # customfunctions
99
+ self.custom_kick = None
100
+ self.custom_drift = None
101
+ self.custom_density = None
102
+ self.custom_plot = None
103
+
94
104
  # simulation state
95
105
  self.state = {}
96
106
  self.state["t"] = 0.0
@@ -170,6 +180,30 @@ class Simulation:
170
180
  / constants["speed_of_light"] ** 2
171
181
  )
172
182
 
183
+ @property
184
+ def scattering_length(self):
185
+ """
186
+ Return the axion self-interaction scattering length in the simulation (kpc)
187
+ """
188
+ f_15 = self.params["quantum"]["f_15"]
189
+ if f_15 == 0.0:
190
+ return 0.0
191
+ else:
192
+ f = f_15 * 1.0e24 * constants["electron_volt"]
193
+ sign = 1.0 if f > 0 else -1.0
194
+ hbar = constants["reduced_planck_constant"]
195
+ c = constants["speed_of_light"]
196
+ m = self.axion_mass
197
+ a_s = sign * (hbar * c**3 * m) / (32.0 * jnp.pi * (f**2))
198
+ return a_s
199
+
200
+ @property
201
+ def m_per_hbar(self):
202
+ """
203
+ Return the mass per hbar in the simulation (M_sun / hbar)
204
+ """
205
+ return self.axion_mass / constants["reduced_planck_constant"]
206
+
173
207
  @property
174
208
  def sound_speed(self):
175
209
  """
@@ -212,8 +246,14 @@ class Simulation:
212
246
  """
213
247
  Return the dark matter velocity field from the wavefunction
214
248
  """
215
- m_per_hbar = self.axion_mass / constants["reduced_planck_constant"]
216
- return quantum_velocity(self.state["psi"], self.box_size, m_per_hbar)
249
+ return quantum_velocity(self.state["psi"], self.box_size, self.m_per_hbar)
250
+
251
+ @property
252
+ def rho_bar(self):
253
+ """
254
+ Return the mean density of the simulation
255
+ """
256
+ return self._calc_rho_bar(self.state)
217
257
 
218
258
  def _calc_rho_bar(self, state):
219
259
  rho_bar = 0.0
@@ -226,6 +266,8 @@ class Simulation:
226
266
  n_particles = self.num_particles
227
267
  box_size = self.box_size
228
268
  rho_bar += m_particle * n_particles / box_size
269
+ if self.custom_density is not None:
270
+ rho_bar += jnp.mean(self.custom_density(state))
229
271
  return rho_bar
230
272
 
231
273
  def _calc_grav_potential(self, state, k_sq):
@@ -239,6 +281,8 @@ class Simulation:
239
281
  rho_tot += state["rho"]
240
282
  if self.params["physics"]["particles"]:
241
283
  rho_tot += bin_particles(state["pos"], self.dx, self.resolution, m_particle)
284
+ if self.custom_density is not None:
285
+ rho_tot += self.custom_density(state)
242
286
  if self.params["physics"]["cosmology"]:
243
287
  scale_factor = 1.0 / (1.0 + state["redshift"])
244
288
  rho_bar *= scale_factor
@@ -271,8 +315,9 @@ class Simulation:
271
315
 
272
316
  # Simulation parameters
273
317
  dx = self.dx
274
- m_per_hbar = self.axion_mass / constants["reduced_planck_constant"]
275
318
  box_size = self.box_size
319
+ num_cells = self.resolution**3
320
+ m_per_hbar = self.m_per_hbar
276
321
 
277
322
  dt_fac = 1.0
278
323
  dt_kin = dt_fac * (m_per_hbar / 6.0) * (dx * dx)
@@ -281,8 +326,20 @@ class Simulation:
281
326
  t_span = t_end - t_start
282
327
  state["t"] = t_start
283
328
 
329
+ use_quantum = self.params["physics"]["quantum"]
330
+ use_gravity = self.params["physics"]["gravity"]
331
+ use_hydro = self.params["physics"]["hydro"]
332
+ use_particles = self.params["physics"]["particles"]
333
+ use_cosmology = self.params["physics"]["cosmology"]
334
+ use_external_potential = self.params["physics"]["external_potential"]
335
+ save = self.params["output"]["save"]
336
+ use_custom = self.custom_kick is not None or self.custom_drift is not None
337
+ if use_custom:
338
+ custom_kick = self.custom_kick
339
+ custom_drift = self.custom_drift
340
+
284
341
  # cosmology
285
- if self.params["physics"]["cosmology"]:
342
+ if use_cosmology:
286
343
  z_start = self.params["time"]["start"]
287
344
  z_end = self.params["time"]["end"]
288
345
  omega_matter = self.params["cosmology"]["omega_matter"]
@@ -294,6 +351,17 @@ class Simulation:
294
351
  state["t"] = 0.0
295
352
  state["redshift"] = z_start
296
353
 
354
+ # self-interaction
355
+ a_s = self.scattering_length
356
+ c = constants["speed_of_light"]
357
+ m = self.axion_mass
358
+ si_coeff = None
359
+ si_coeff2 = None
360
+ do_self_interaction = a_s != 0.0
361
+ if do_self_interaction:
362
+ si_coeff = (4.0 * jnp.pi) * (a_s / m) / m_per_hbar**2
363
+ si_coeff2 = (32.0 * jnp.pi**2 / 3.0) * (a_s / (m * c)) ** 2 / m_per_hbar**5
364
+
297
365
  # hydro
298
366
  c_sound = self.params["hydro"]["sound_speed"]
299
367
 
@@ -308,16 +376,14 @@ class Simulation:
308
376
  k_sq = None
309
377
 
310
378
  # Fourier space variables
311
- if self.params["physics"]["gravity"] or self.params["physics"]["quantum"]:
379
+ if use_gravity or use_quantum:
312
380
  kx, ky, kz = self.kgrid
313
381
  k_sq = kx**2 + ky**2 + kz**2
314
382
 
315
383
  # Checkpointer
316
- if self.params["output"]["save"]:
384
+ if save:
317
385
  options = ocp.CheckpointManagerOptions()
318
- checkpoint_dir = checkpoint_dir = os.path.join(
319
- os.getcwd(), self.params["output"]["path"]
320
- )
386
+ checkpoint_dir = os.path.join(os.getcwd(), self.params["output"]["path"])
321
387
  path = os.path.join(os.getcwd(), checkpoint_dir)
322
388
  if jax.process_index() == 0:
323
389
  path = ocp.test_utils.erase_and_create_empty(checkpoint_dir)
@@ -327,80 +393,86 @@ class Simulation:
327
393
 
328
394
  def _kick(state, kx, ky, kz, k_sq, dt):
329
395
  # Kick (half-step)
330
- if (
331
- self.params["physics"]["gravity"]
332
- and self.params["physics"]["external_potential"]
333
- ):
396
+ if use_gravity and use_external_potential:
334
397
  V = self._calc_grav_potential(state, k_sq) + state["V_ext"]
335
- elif self.params["physics"]["gravity"]:
398
+ elif use_gravity:
336
399
  V = self._calc_grav_potential(state, k_sq)
337
- elif self.params["physics"]["external_potential"]:
400
+ elif use_external_potential:
338
401
  V = state["V_ext"]
339
402
 
340
- if (
341
- self.params["physics"]["gravity"]
342
- or self.params["physics"]["external_potential"]
343
- ):
344
- if self.params["physics"]["quantum"]:
345
- state["psi"] = quantum_kick(state["psi"], V, m_per_hbar, dt)
346
- if self.params["physics"]["hydro"]:
403
+ if use_gravity or use_external_potential:
404
+ if use_quantum:
405
+ if do_self_interaction:
406
+ rho = jnp.abs(state["psi"]) ** 2
407
+ V_prime = V + si_coeff * rho + si_coeff2 * rho**2
408
+ state["psi"] = quantum_kick(
409
+ state["psi"], V_prime, m_per_hbar, dt
410
+ )
411
+ else:
412
+ state["psi"] = quantum_kick(state["psi"], V, m_per_hbar, dt)
413
+ if use_hydro:
347
414
  state["vx"], state["vy"], state["vz"] = hydro_accelerate(
348
415
  state["vx"], state["vy"], state["vz"], V, kx, ky, kz, dt
349
416
  )
350
- if self.params["physics"]["particles"]:
417
+ if use_particles:
351
418
  state["vel"] = particles_accelerate(
352
419
  state["vel"], state["pos"], V, kx, ky, kz, dx, dt
353
420
  )
421
+ if use_custom:
422
+ state = custom_kick(state, V, dt)
423
+
424
+ return state
354
425
 
355
426
  def _drift(state, k_sq, dt):
356
427
  # Drift (full-step)
357
- if self.params["physics"]["quantum"]:
428
+ if use_quantum:
358
429
  state["psi"] = quantum_drift(state["psi"], k_sq, m_per_hbar, dt)
359
- if self.params["physics"]["hydro"]:
430
+ if use_hydro:
360
431
  state["rho"], state["vx"], state["vy"], state["vz"] = hydro_fluxes(
361
432
  state["rho"], state["vx"], state["vy"], state["vz"], dt, dx, c_sound
362
433
  )
363
- if self.params["physics"]["particles"]:
434
+ if use_particles:
364
435
  state["pos"] = particles_drift(state["pos"], state["vel"], dt, box_size)
436
+ if use_custom:
437
+ state = custom_drift(state, k_sq, dt)
438
+
439
+ return state
365
440
 
366
- @jax.jit
367
441
  def _update(_, carry):
368
442
  # Update the simulation state by one timestep
369
443
  # according to a 2nd-order `kick-drift-kick` scheme
370
444
  state, kx, ky, kz, k_sq = carry
371
- _kick(state, kx, ky, kz, k_sq, 0.5 * dt)
372
- _drift(state, k_sq, dt)
445
+ state = _kick(state, kx, ky, kz, k_sq, 0.5 * dt)
446
+ state = _drift(state, k_sq, dt)
373
447
  # update time & redshift
374
448
  state["t"] += dt
375
- if self.params["physics"]["cosmology"]:
449
+ if use_cosmology:
376
450
  scale_factor = get_next_scale_factor(
377
- state["redshift"],
378
- dt,
379
- self.params["cosmology"]["omega_matter"],
380
- self.params["cosmology"]["omega_lambda"],
381
- self.params["cosmology"]["little_h"],
451
+ state["redshift"], dt, omega_matter, omega_lambda, little_h
382
452
  )
383
453
  state["redshift"] = 1.0 / scale_factor - 1.0
384
- _kick(state, kx, ky, kz, k_sq, 0.5 * dt)
454
+ state = _kick(state, kx, ky, kz, k_sq, 0.5 * dt)
385
455
 
386
456
  return state, kx, ky, kz, k_sq
387
457
 
388
458
  # save initial state
389
459
  if jax.process_index() == 0:
390
460
  print(f"Starting simulation (res={self.resolution}, nt={nt}) ...")
391
- if self.params["output"]["save"]:
461
+ if save:
392
462
  with open(os.path.join(checkpoint_dir, "params.json"), "w") as f:
393
463
  json.dump(self.params, f, indent=2)
394
464
  async_checkpoint_manager.save(0, args=ocp.args.StandardSave(state))
395
465
  plot_sim(state, checkpoint_dir, 0, self.params)
466
+ if self.custom_plot is not None:
467
+ self.custom_plot(state, checkpoint_dir, 0, self.params)
396
468
  async_checkpoint_manager.wait_until_finished()
397
469
 
398
470
  # Simulation Main Loop
399
471
  t_start_timer = time.time()
400
- if self.params["output"]["save"]:
472
+ if save:
401
473
  for i in range(1, num_checkpoints + 1):
402
474
  carry = jax.lax.fori_loop(0, nt_sub, _update, init_val=carry)
403
- state, kx, ky, kz, k_sq = carry
475
+ state, _, _, _, _ = carry
404
476
  jax.block_until_ready(state)
405
477
  # save state
406
478
  async_checkpoint_manager.save(i, args=ocp.args.StandardSave(state))
@@ -408,14 +480,18 @@ class Simulation:
408
480
  elapsed = time.time() - t_start_timer
409
481
  est_total = elapsed / i * num_checkpoints
410
482
  est_remaining = est_total - elapsed
483
+ mcups = (num_cells * (i * nt_sub)) / (elapsed * 1.0e6)
411
484
  if jax.process_index() == 0:
412
485
  print(
413
- f"{percent:.1f}%: estimated time remaining (s): {est_remaining:.1f}"
486
+ f"{percent:.1f}%: mcups={mcups:.1f}, estimated time left (s): {est_remaining:.1f}"
414
487
  )
415
488
  plot_sim(state, checkpoint_dir, i, self.params)
489
+ if self.custom_plot is not None:
490
+ self.custom_plot(state, checkpoint_dir, i, self.params)
416
491
  async_checkpoint_manager.wait_until_finished()
417
492
  else:
418
- state = jax.lax.fori_loop(0, nt, _update, init_val=state)
493
+ carry = jax.lax.fori_loop(0, nt, _update, init_val=carry)
494
+ state, _, _, _, _ = carry
419
495
  jax.block_until_ready(state)
420
496
  if jax.process_index() == 0:
421
497
  print("Simulation Run Time (s): ", time.time() - t_start_timer)
@@ -5,6 +5,7 @@ from pathlib import Path
5
5
  import importlib.resources
6
6
  import json
7
7
  from importlib.metadata import version
8
+ import jax
8
9
  import jax.numpy as jnp
9
10
 
10
11
 
@@ -12,9 +13,25 @@ def print_parameters(params):
12
13
  print(json.dumps(params, indent=2))
13
14
 
14
15
 
16
+ def print_distributed_info():
17
+ for env_var in [
18
+ "SLURM_JOB_ID",
19
+ "SLURM_NTASKS",
20
+ "SLURM_NODELIST",
21
+ "SLURM_STEP_NODELIST",
22
+ "SLURM_STEP_GPUS",
23
+ "SLURM_GPUS",
24
+ ]:
25
+ print(f"{env_var}: {os.getenv(env_var, '')}")
26
+ print("Total number of processes: ", jax.process_count())
27
+ print("Total number of devices: ", jax.device_count())
28
+ print("List of devices: ", jax.devices())
29
+ print("Number of devices on this process: ", jax.local_device_count())
30
+
31
+
15
32
  def set_up_parameters(user_overwrites):
16
33
  # first load the default params
17
- params_path = importlib.resources.files("jaxion") / "params_default.json"
34
+ params_path = importlib.resources.files("jaxion") / "defaults.json"
18
35
  with params_path.open("r", encoding="utf-8") as f:
19
36
  params = json.load(f)
20
37
 
@@ -13,15 +13,19 @@ def plot_sim(state, checkpoint_dir, i, params):
13
13
  if params["physics"]["quantum"]:
14
14
  nx = state["psi"].shape[0]
15
15
  rho_bar_dm = jnp.mean(jnp.abs(state["psi"]) ** 2)
16
- rho_proj_dm = jax.experimental.multihost_utils.process_allgather(
17
- jnp.log10(jnp.mean(jnp.abs(state["psi"]) ** 2, axis=2).T)
18
- ).reshape(nx, nx)
16
+ rho_proj_dm = jnp.log10(
17
+ jax.experimental.multihost_utils.process_allgather(
18
+ jnp.mean(jnp.abs(state["psi"]) ** 2, axis=2), tiled=True
19
+ )
20
+ ).T
19
21
  if params["physics"]["hydro"]:
20
22
  nx = state["rho"].shape[0]
21
23
  rho_bar_gas = jnp.mean(state["rho"])
22
- rho_proj_gas = jax.experimental.multihost_utils.process_allgather(
23
- jnp.log10(jnp.mean(state["rho"], axis=2).T)
24
- ).reshape(nx, nx)
24
+ rho_proj_gas = jnp.log10(
25
+ jax.experimental.multihost_utils.process_allgather(
26
+ jnp.mean(state["rho"], axis=2), tiled=True
27
+ )
28
+ ).T
25
29
 
26
30
  # create plot on process 0
27
31
  if jax.process_index() == 0:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: jaxion
3
- Version: 0.0.4
3
+ Version: 0.0.6
4
4
  Summary: A differentiable simulation library for fuzzy dark matter in JAX
5
5
  Author-email: Philip Mocz <philip.mocz@gmail.com>
6
6
  License-Expression: Apache-2.0
@@ -33,8 +33,11 @@ Dynamic: license-file
33
33
  [![PyPI Version Status][pypi-badge]][pypi-link]
34
34
  [![Test Status][workflow-test-badge]][workflow-test-link]
35
35
  [![Coverage][coverage-badge]][coverage-link]
36
+ [![Ruff][ruff-badge]][ruff-link]
37
+ [![asv][asv-badge]][asv-link]
36
38
  [![Readthedocs Status][docs-badge]][docs-link]
37
39
  [![License][license-badge]][license-link]
40
+ [![Software DOI][software-doi-badge]][software-doi-link]
38
41
 
39
42
  [status-link]: https://www.repostatus.org/#active
40
43
  [status-badge]: https://www.repostatus.org/badges/latest/active.svg
@@ -44,10 +47,17 @@ Dynamic: license-file
44
47
  [workflow-test-badge]: https://github.com/JaxionProject/jaxion/actions/workflows/test-package.yml/badge.svg?event=push
45
48
  [coverage-link]: https://app.codecov.io/gh/JaxionProject/jaxion
46
49
  [coverage-badge]: https://codecov.io/github/jaxionproject/jaxion/graph/jaxion-server/badge.svg
50
+ [ruff-link]: https://github.com/astral-sh/ruff
51
+ [ruff-badge]: https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json
52
+ [asv-link]: https://jaxionproject.github.io/jaxion-benchmarks/
53
+ [asv-badge]: https://img.shields.io/badge/benchmarked%20by-asv-blue.svg?style=flat
47
54
  [docs-link]: https://jaxion.readthedocs.io
48
55
  [docs-badge]: https://readthedocs.org/projects/jaxion/badge
49
56
  [license-link]: https://opensource.org/licenses/Apache-2.0
50
57
  [license-badge]: https://img.shields.io/badge/License-Apache_2.0-blue.svg
58
+ [software-doi-link]: https://doi.org/10.5281/zenodo.17438467
59
+ [software-doi-badge]: https://zenodo.org/badge/1072645376.svg
60
+
51
61
 
52
62
  A simple JAX-powered simulation library for numerical experiments of fuzzy dark matter, stars, gas + more!
53
63
 
@@ -58,7 +68,7 @@ Jaxion is built for multi-GPU scalability and is fully differentiable. It is a h
58
68
  Jaxion is the simpler companion project to differentiable astrophysics code [Adirondax](https://github.com/AdirondaxProject/adirondax)
59
69
 
60
70
 
61
- ## Getting started
71
+ ## Install Jaxion
62
72
 
63
73
  Install with:
64
74
 
@@ -77,7 +87,7 @@ See the docs for more info on how to [build from source](https://jaxion.readthed
77
87
 
78
88
  ## Examples
79
89
 
80
- Check out the `examples/` directory for demonstrations of using Jaxion.
90
+ Check out the [`examples/`](https://github.com/JaxionProject/jaxion/tree/main/examples/) directory for demonstrations of using Jaxion.
81
91
 
82
92
  <p align="center">
83
93
  <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/cosmological_box">
@@ -96,8 +106,8 @@ Check out the `examples/` directory for demonstrations of using Jaxion.
96
106
  <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/kinetic_condensation">
97
107
  <img src="examples/kinetic_condensation/movie.gif" alt="kinetic_condensation" width="128"/>
98
108
  </a>
99
- <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/logo">
100
- <img src="examples/logo/movie.gif" alt="logo" width="128"/>
109
+ <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/logo_inverse_problem">
110
+ <img src="examples/logo_inverse_problem/movie.gif" alt="logo_inverse_problem" width="128"/>
101
111
  </a>
102
112
  <br>
103
113
  <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/soliton_binary_merger">
@@ -112,22 +122,40 @@ Check out the `examples/` directory for demonstrations of using Jaxion.
112
122
  </p>
113
123
 
114
124
 
115
- ## Links
125
+ ## High-Performance
116
126
 
117
- * [Code repository](https://github.com/JaxionProject/jaxion) on GitHub (this page).
118
- * [Documentation](https://jaxion.readthedocs.io) for up-to-date information about installing and using jaxion.
127
+ Jaxion scales across multiple GPUs!
119
128
 
129
+ <p align="center">
130
+ <a href="https://jaxion.readthedocs.io">
131
+ <img src="examples/soliton_binary_merger/timing.png" alt="timing" width="400"/>
132
+ </a>
133
+ </p>
120
134
 
121
- ## Testing
122
135
 
123
- Jaxion is tested with `pytest`. Tests are included in the `tests/` folder.
136
+ ## Contributing
124
137
 
138
+ Jaxion welcomes community contributions of all kinds. Open an issue or fork the code and submit a pull request. Please check out the [Contributing Guidelines](CONTRIBUTING.md).
125
139
 
126
- ## Contributing
127
140
 
128
- Jaxion welcomes community contributions of all kinds. Open an issue or fork the code and submit a Pull request. Please check out the [Contributing Guidelines](CONTRIBUTING.md)
141
+ ## Links
142
+
143
+ * [Code repository](https://github.com/JaxionProject/jaxion) on GitHub (this page).
144
+ * [Documentation](https://jaxion.readthedocs.io) for up-to-date information about installing and using Jaxion.
129
145
 
130
146
 
131
147
  ## Cite this repository
132
148
 
133
- TODO XXX
149
+ If you use this software, please cite it as follows:
150
+
151
+ ```bibtex
152
+ @software{Mocz_Jaxion_2025,
153
+ author = {Mocz, Philip},
154
+ doi = {10.5281/zenodo.17438467},
155
+ month = oct,
156
+ title = {{Jaxion}},
157
+ url = {https://github.com/JaxionProject/jaxion},
158
+ version = {0.0.4},
159
+ year = {2025}
160
+ }
161
+ ```
@@ -6,9 +6,9 @@ jaxion/__init__.py
6
6
  jaxion/analysis.py
7
7
  jaxion/constants.py
8
8
  jaxion/cosmology.py
9
+ jaxion/defaults.json
9
10
  jaxion/gravity.py
10
11
  jaxion/hydro.py
11
- jaxion/params_default.json
12
12
  jaxion/particles.py
13
13
  jaxion/quantum.py
14
14
  jaxion/simulation.py
@@ -19,4 +19,6 @@ jaxion.egg-info/SOURCES.txt
19
19
  jaxion.egg-info/dependency_links.txt
20
20
  jaxion.egg-info/requires.txt
21
21
  jaxion.egg-info/top_level.txt
22
+ tests/test_analysis.py
23
+ tests/test_cosmology.py
22
24
  tests/test_examples.py
@@ -10,7 +10,7 @@ license = "Apache-2.0"
10
10
  dynamic = ["version", "dependencies"]
11
11
 
12
12
  [tool.setuptools.packages.find]
13
- include = ["jaxion"]
13
+ include = ["jaxio*"]
14
14
  exclude = ["paper"]
15
15
 
16
16
  [tool.setuptools.package-data]
@@ -0,0 +1,24 @@
1
+ import jax.numpy as jnp
2
+ import jaxion
3
+ from jaxion.quantum import quantum_velocity
4
+ from jaxion.analysis import radial_power_spectrum
5
+ import pytest
6
+
7
+
8
+ def test_quantum_velocity_and_radial_power_spectrum():
9
+ sim = jaxion.Simulation({})
10
+ box_size = sim.box_size
11
+ m_per_hbar = sim.m_per_hbar
12
+ xx, yy, _ = sim.grid
13
+ kx, ky, kz = sim.kgrid
14
+
15
+ psi = (
16
+ jnp.cos(2.0 * jnp.pi * xx / box_size) ** 2
17
+ + jnp.cos(2.0 * jnp.pi * yy / box_size) * 1j
18
+ )
19
+
20
+ vx, _, _ = quantum_velocity(psi, box_size, m_per_hbar)
21
+
22
+ Pf, _, _ = radial_power_spectrum(vx, kx, ky, kz, box_size)
23
+
24
+ assert jnp.max(Pf) == pytest.approx(8244.606, rel=1e-4)
@@ -0,0 +1,29 @@
1
+ from jaxion.cosmology import (
2
+ get_physical_time_interval,
3
+ get_supercomoving_time_interval,
4
+ get_scale_factor,
5
+ get_next_scale_factor,
6
+ )
7
+ import pytest
8
+
9
+
10
+ def test_cosmology_functions():
11
+ z_start = 127.0
12
+ z_end = 0.0
13
+ omega_matter = 0.3
14
+ omega_lambda = 0.7
15
+ little_h = 0.7
16
+ dt_hat = 10.0
17
+
18
+ assert get_physical_time_interval(
19
+ z_start, z_end, omega_matter, omega_lambda, little_h
20
+ ) == pytest.approx(13.76084)
21
+ assert get_supercomoving_time_interval(
22
+ z_start, z_end, omega_matter, omega_lambda, little_h
23
+ ) == pytest.approx(530.44415)
24
+ assert get_scale_factor(
25
+ z_start, dt_hat, omega_matter, omega_lambda, little_h
26
+ ) == pytest.approx(0.008084139320999384)
27
+ assert get_next_scale_factor(
28
+ z_start, dt_hat, omega_matter, omega_lambda, little_h
29
+ ) == pytest.approx(0.00808401)
@@ -7,11 +7,22 @@ rel_tol = 1e-4
7
7
 
8
8
  def test_tidal_stripping():
9
9
  sim = run_example_main(
10
- "examples/tidal_stripping/tidal_stripping.py", argv=["--res", "1"]
10
+ "examples/tidal_stripping/tidal_stripping.py",
11
+ argv=["--res", "1", "--save", "False"],
11
12
  )
12
13
  assert sim.resolution == 32
13
14
  assert sim.state["t"] > 0.0
14
- assert jnp.mean(jnp.abs(sim.state["psi"])) == pytest.approx(639.0479, rel=rel_tol)
15
+ assert jnp.mean(jnp.abs(sim.state["psi"])) == pytest.approx(162.028, rel=rel_tol)
16
+
17
+
18
+ def test_tidal_stripping_distributed_emulate():
19
+ sim = run_example_main(
20
+ "examples/tidal_stripping/tidal_stripping.py",
21
+ argv=["--res", "1", "--distributed", "--emulate"],
22
+ )
23
+ assert sim.resolution == 32
24
+ assert sim.state["t"] > 0.0
25
+ assert jnp.mean(jnp.abs(sim.state["psi"])) == pytest.approx(162.028, rel=rel_tol)
15
26
 
16
27
 
17
28
  def test_heating_gas():
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes