jaxion 0.0.4__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxion/simulation.py CHANGED
@@ -14,6 +14,7 @@ from .cosmology import get_supercomoving_time_interval, get_next_scale_factor
14
14
  from .utils import (
15
15
  set_up_parameters,
16
16
  print_parameters,
17
+ print_distributed_info,
17
18
  xmeshgrid,
18
19
  xmeshgrid_transpose,
19
20
  xzeros,
@@ -72,10 +73,12 @@ class Simulation:
72
73
  "hydro/particles sharding is not yet implemented."
73
74
  )
74
75
 
75
- if self.params["output"]["save"]:
76
- if jax.process_index() == 0:
77
- print("Simulation parameters:")
78
- print_parameters(self.params)
76
+ # print info
77
+ if jax.process_index() == 0:
78
+ print("Simulation parameters:")
79
+ print_parameters(self.params)
80
+ if sharding is not None:
81
+ print_distributed_info()
79
82
 
80
83
  # jitted functions
81
84
  self.xmeshgrid_jit = jax.jit(
@@ -170,6 +173,13 @@ class Simulation:
170
173
  / constants["speed_of_light"] ** 2
171
174
  )
172
175
 
176
+ @property
177
+ def m_per_hbar(self):
178
+ """
179
+ Return the mass per hbar in the simulation (M_sun / hbar)
180
+ """
181
+ return self.axion_mass / constants["reduced_planck_constant"]
182
+
173
183
  @property
174
184
  def sound_speed(self):
175
185
  """
@@ -212,8 +222,7 @@ class Simulation:
212
222
  """
213
223
  Return the dark matter velocity field from the wavefunction
214
224
  """
215
- m_per_hbar = self.axion_mass / constants["reduced_planck_constant"]
216
- return quantum_velocity(self.state["psi"], self.box_size, m_per_hbar)
225
+ return quantum_velocity(self.state["psi"], self.box_size, self.m_per_hbar)
217
226
 
218
227
  def _calc_rho_bar(self, state):
219
228
  rho_bar = 0.0
@@ -271,7 +280,7 @@ class Simulation:
271
280
 
272
281
  # Simulation parameters
273
282
  dx = self.dx
274
- m_per_hbar = self.axion_mass / constants["reduced_planck_constant"]
283
+ m_per_hbar = self.m_per_hbar
275
284
  box_size = self.box_size
276
285
 
277
286
  dt_fac = 1.0
@@ -408,14 +417,17 @@ class Simulation:
408
417
  elapsed = time.time() - t_start_timer
409
418
  est_total = elapsed / i * num_checkpoints
410
419
  est_remaining = est_total - elapsed
420
+ num_cells = self.resolution**3
421
+ mcups = (num_cells * (i * nt_sub)) / (elapsed * 1.0e6)
411
422
  if jax.process_index() == 0:
412
423
  print(
413
- f"{percent:.1f}%: estimated time remaining (s): {est_remaining:.1f}"
424
+ f"{percent:.1f}%: mcups={mcups:.1f}, estimated time left (s): {est_remaining:.1f}"
414
425
  )
415
426
  plot_sim(state, checkpoint_dir, i, self.params)
416
427
  async_checkpoint_manager.wait_until_finished()
417
428
  else:
418
- state = jax.lax.fori_loop(0, nt, _update, init_val=state)
429
+ carry = jax.lax.fori_loop(0, nt, _update, init_val=carry)
430
+ state, kx, ky, kz, k_sq = carry
419
431
  jax.block_until_ready(state)
420
432
  if jax.process_index() == 0:
421
433
  print("Simulation Run Time (s): ", time.time() - t_start_timer)
jaxion/utils.py CHANGED
@@ -5,6 +5,7 @@ from pathlib import Path
5
5
  import importlib.resources
6
6
  import json
7
7
  from importlib.metadata import version
8
+ import jax
8
9
  import jax.numpy as jnp
9
10
 
10
11
 
@@ -12,6 +13,22 @@ def print_parameters(params):
12
13
  print(json.dumps(params, indent=2))
13
14
 
14
15
 
16
+ def print_distributed_info():
17
+ for env_var in [
18
+ "SLURM_JOB_ID",
19
+ "SLURM_NTASKS",
20
+ "SLURM_NODELIST",
21
+ "SLURM_STEP_NODELIST",
22
+ "SLURM_STEP_GPUS",
23
+ "SLURM_GPUS",
24
+ ]:
25
+ print(f"{env_var}: {os.getenv(env_var, '')}")
26
+ print("Total number of processes: ", jax.process_count())
27
+ print("Total number of devices: ", jax.device_count())
28
+ print("List of devices: ", jax.devices())
29
+ print("Number of devices on this process: ", jax.local_device_count())
30
+
31
+
15
32
  def set_up_parameters(user_overwrites):
16
33
  # first load the default params
17
34
  params_path = importlib.resources.files("jaxion") / "params_default.json"
jaxion/visualization.py CHANGED
@@ -13,15 +13,19 @@ def plot_sim(state, checkpoint_dir, i, params):
13
13
  if params["physics"]["quantum"]:
14
14
  nx = state["psi"].shape[0]
15
15
  rho_bar_dm = jnp.mean(jnp.abs(state["psi"]) ** 2)
16
- rho_proj_dm = jax.experimental.multihost_utils.process_allgather(
17
- jnp.log10(jnp.mean(jnp.abs(state["psi"]) ** 2, axis=2).T)
18
- ).reshape(nx, nx)
16
+ rho_proj_dm = jnp.log10(
17
+ jax.experimental.multihost_utils.process_allgather(
18
+ jnp.mean(jnp.abs(state["psi"]) ** 2, axis=2), tiled=True
19
+ )
20
+ ).T
19
21
  if params["physics"]["hydro"]:
20
22
  nx = state["rho"].shape[0]
21
23
  rho_bar_gas = jnp.mean(state["rho"])
22
- rho_proj_gas = jax.experimental.multihost_utils.process_allgather(
23
- jnp.log10(jnp.mean(state["rho"], axis=2).T)
24
- ).reshape(nx, nx)
24
+ rho_proj_gas = jnp.log10(
25
+ jax.experimental.multihost_utils.process_allgather(
26
+ jnp.mean(state["rho"], axis=2), tiled=True
27
+ )
28
+ ).T
25
29
 
26
30
  # create plot on process 0
27
31
  if jax.process_index() == 0:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: jaxion
3
- Version: 0.0.4
3
+ Version: 0.0.5
4
4
  Summary: A differentiable simulation library for fuzzy dark matter in JAX
5
5
  Author-email: Philip Mocz <philip.mocz@gmail.com>
6
6
  License-Expression: Apache-2.0
@@ -33,8 +33,11 @@ Dynamic: license-file
33
33
  [![PyPI Version Status][pypi-badge]][pypi-link]
34
34
  [![Test Status][workflow-test-badge]][workflow-test-link]
35
35
  [![Coverage][coverage-badge]][coverage-link]
36
+ [![Ruff][ruff-badge]][ruff-link]
37
+ [![asv][asv-badge]][asv-link]
36
38
  [![Readthedocs Status][docs-badge]][docs-link]
37
39
  [![License][license-badge]][license-link]
40
+ [![Software DOI][software-doi-badge]][software-doi-link]
38
41
 
39
42
  [status-link]: https://www.repostatus.org/#active
40
43
  [status-badge]: https://www.repostatus.org/badges/latest/active.svg
@@ -44,10 +47,17 @@ Dynamic: license-file
44
47
  [workflow-test-badge]: https://github.com/JaxionProject/jaxion/actions/workflows/test-package.yml/badge.svg?event=push
45
48
  [coverage-link]: https://app.codecov.io/gh/JaxionProject/jaxion
46
49
  [coverage-badge]: https://codecov.io/github/jaxionproject/jaxion/graph/jaxion-server/badge.svg
50
+ [ruff-link]: https://github.com/astral-sh/ruff
51
+ [ruff-badge]: https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json
52
+ [asv-link]: https://jaxionproject.github.io/jaxion-benchmarks/
53
+ [asv-badge]: https://img.shields.io/badge/benchmarked%20by-asv-blue.svg?style=flat
47
54
  [docs-link]: https://jaxion.readthedocs.io
48
55
  [docs-badge]: https://readthedocs.org/projects/jaxion/badge
49
56
  [license-link]: https://opensource.org/licenses/Apache-2.0
50
57
  [license-badge]: https://img.shields.io/badge/License-Apache_2.0-blue.svg
58
+ [software-doi-link]: https://doi.org/10.5281/zenodo.17438467
59
+ [software-doi-badge]: https://zenodo.org/badge/1072645376.svg
60
+
51
61
 
52
62
  A simple JAX-powered simulation library for numerical experiments of fuzzy dark matter, stars, gas + more!
53
63
 
@@ -96,8 +106,8 @@ Check out the `examples/` directory for demonstrations of using Jaxion.
96
106
  <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/kinetic_condensation">
97
107
  <img src="examples/kinetic_condensation/movie.gif" alt="kinetic_condensation" width="128"/>
98
108
  </a>
99
- <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/logo">
100
- <img src="examples/logo/movie.gif" alt="logo" width="128"/>
109
+ <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/logo_inverse_problem">
110
+ <img src="examples/logo_inverse_problem/movie.gif" alt="logo_inverse_problem" width="128"/>
101
111
  </a>
102
112
  <br>
103
113
  <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/soliton_binary_merger">
@@ -112,22 +122,40 @@ Check out the `examples/` directory for demonstrations of using Jaxion.
112
122
  </p>
113
123
 
114
124
 
115
- ## Links
125
+ ## High-Performance
116
126
 
117
- * [Code repository](https://github.com/JaxionProject/jaxion) on GitHub (this page).
118
- * [Documentation](https://jaxion.readthedocs.io) for up-to-date information about installing and using jaxion.
127
+ Jaxion is scalable on multiple GPUs!
119
128
 
129
+ <p align="center">
130
+ <a href="https://jaxion.readthedocs.io">
131
+ <img src="examples/soliton_binary_merger/timing.png" alt="timing" width="400"/>
132
+ </a>
133
+ </p>
120
134
 
121
- ## Testing
122
135
 
123
- Jaxion is tested with `pytest`. Tests are included in the `tests/` folder.
136
+ ## Contributing
124
137
 
138
+ Jaxion welcomes community contributions of all kinds. Open an issue or fork the code and submit a pull request. Please check out the [Contributing Guidelines](CONTRIBUTING.md)
125
139
 
126
- ## Contributing
127
140
 
128
- Jaxion welcomes community contributions of all kinds. Open an issue or fork the code and submit a Pull request. Please check out the [Contributing Guidelines](CONTRIBUTING.md)
141
+ ## Links
142
+
143
+ * [Code repository](https://github.com/JaxionProject/jaxion) on GitHub (this page).
144
+ * [Documentation](https://jaxion.readthedocs.io) for up-to-date information about installing and using Jaxion.
129
145
 
130
146
 
131
147
  ## Cite this repository
132
148
 
133
- TODO XXX
149
+ If you use this software, please cite it as below.
150
+
151
+ ```bibtex
152
+ @software{Mocz_Jaxion_2025,
153
+ author = {Mocz, Philip},
154
+ doi = {10.5281/zenodo.17438467},
155
+ month = oct,
156
+ title = {{Jaxion}},
157
+ url = {https://github.com/JaxionProject/jaxion},
158
+ version = {0.0.4},
159
+ year = {2025}
160
+ }
161
+ ```
@@ -7,11 +7,11 @@ jaxion/hydro.py,sha256=KoJ02tRpAc4V3Ofzw4zbHLRaE2GdIatbOBE04_LsSRw,6980
7
7
  jaxion/params_default.json,sha256=9CJrhEPsv5zGEs7_WqFyuccCDipPCDhXgKzVdqOsOWE,2775
8
8
  jaxion/particles.py,sha256=pMopGvoZ0J_3EviD0WnTMmiebU9h2_8IO-p6I-E5DEU,3980
9
9
  jaxion/quantum.py,sha256=GWOpN6ipfEw-6Ah2zQpxS3oqeSt_iHMDSgnVYSjXY5E,3321
10
- jaxion/simulation.py,sha256=s6gCAt-gAoN5d46vcdxoqtn4TwsrfNGb4Cq-2p_JxsI,15927
11
- jaxion/utils.py,sha256=rT7NM0FNEgFwN7oTgTb-jkR66Iw0xYTHHxcoikYd1ag,3572
12
- jaxion/visualization.py,sha256=K5EQOHPfj7LF29fW_naWH8a7TEyEa3wIaQw7rpebx0w,2914
13
- jaxion-0.0.4.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
14
- jaxion-0.0.4.dist-info/METADATA,sha256=66lU0x1ZofP-uCwD4hF0C45Ao-pNhPvMzr6Krs__Hws,5305
15
- jaxion-0.0.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
16
- jaxion-0.0.4.dist-info/top_level.txt,sha256=S1OV2VdlDG_9UwpKOIji4itQGOS-VWUOWUi3GeXWzt0,7
17
- jaxion-0.0.4.dist-info/RECORD,,
10
+ jaxion/simulation.py,sha256=2YkHh3tUVvg5tUNnOxf4s7wGeuMntYUJcJWV0M-3Pl8,16267
11
+ jaxion/utils.py,sha256=f8SvJjqzcW2K91qbPNqrsjfVjyPShuf50yoSHc0YqYE,4093
12
+ jaxion/visualization.py,sha256=Vx_xuEE7BB1AkEGaY9KHNSIIDxJzUVT104o-3uglW8o,2966
13
+ jaxion-0.0.5.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
14
+ jaxion-0.0.5.dist-info/METADATA,sha256=Ot_IWpsCk6l7KX94fCs4j1LLcJQLlQFmp-jYCyj_yWY,6378
15
+ jaxion-0.0.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
16
+ jaxion-0.0.5.dist-info/top_level.txt,sha256=S1OV2VdlDG_9UwpKOIji4itQGOS-VWUOWUi3GeXWzt0,7
17
+ jaxion-0.0.5.dist-info/RECORD,,
File without changes