jaxion 0.0.7__tar.gz → 0.0.9__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28) hide show
  1. {jaxion-0.0.7 → jaxion-0.0.9}/PKG-INFO +7 -6
  2. {jaxion-0.0.7 → jaxion-0.0.9}/jaxion/simulation.py +37 -9
  3. {jaxion-0.0.7 → jaxion-0.0.9}/jaxion.egg-info/PKG-INFO +7 -6
  4. jaxion-0.0.9/jaxion.egg-info/requires.txt +12 -0
  5. {jaxion-0.0.7 → jaxion-0.0.9}/pyproject.toml +1 -1
  6. jaxion-0.0.9/requirements.txt +9 -0
  7. jaxion-0.0.7/jaxion.egg-info/requires.txt +0 -11
  8. jaxion-0.0.7/requirements.txt +0 -8
  9. {jaxion-0.0.7 → jaxion-0.0.9}/LICENSE +0 -0
  10. {jaxion-0.0.7 → jaxion-0.0.9}/README.md +0 -0
  11. {jaxion-0.0.7 → jaxion-0.0.9}/jaxion/__init__.py +0 -0
  12. {jaxion-0.0.7 → jaxion-0.0.9}/jaxion/analysis.py +0 -0
  13. {jaxion-0.0.7 → jaxion-0.0.9}/jaxion/constants.py +0 -0
  14. {jaxion-0.0.7 → jaxion-0.0.9}/jaxion/cosmology.py +0 -0
  15. {jaxion-0.0.7 → jaxion-0.0.9}/jaxion/defaults.json +0 -0
  16. {jaxion-0.0.7 → jaxion-0.0.9}/jaxion/gravity.py +0 -0
  17. {jaxion-0.0.7 → jaxion-0.0.9}/jaxion/hydro.py +0 -0
  18. {jaxion-0.0.7 → jaxion-0.0.9}/jaxion/particles.py +0 -0
  19. {jaxion-0.0.7 → jaxion-0.0.9}/jaxion/quantum.py +0 -0
  20. {jaxion-0.0.7 → jaxion-0.0.9}/jaxion/utils.py +0 -0
  21. {jaxion-0.0.7 → jaxion-0.0.9}/jaxion/visualization.py +0 -0
  22. {jaxion-0.0.7 → jaxion-0.0.9}/jaxion.egg-info/SOURCES.txt +0 -0
  23. {jaxion-0.0.7 → jaxion-0.0.9}/jaxion.egg-info/dependency_links.txt +0 -0
  24. {jaxion-0.0.7 → jaxion-0.0.9}/jaxion.egg-info/top_level.txt +0 -0
  25. {jaxion-0.0.7 → jaxion-0.0.9}/setup.cfg +0 -0
  26. {jaxion-0.0.7 → jaxion-0.0.9}/tests/test_analysis.py +0 -0
  27. {jaxion-0.0.7 → jaxion-0.0.9}/tests/test_cosmology.py +0 -0
  28. {jaxion-0.0.7 → jaxion-0.0.9}/tests/test_examples.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: jaxion
3
- Version: 0.0.7
3
+ Version: 0.0.9
4
4
  Summary: A differentiable simulation library for fuzzy dark matter in JAX
5
5
  Author-email: Philip Mocz <philip.mocz@gmail.com>
6
6
  License-Expression: Apache-2.0
@@ -9,16 +9,17 @@ Project-URL: Homepage, https://github.com/JaxionProject/jaxion
9
9
  Requires-Python: >=3.11
10
10
  Description-Content-Type: text/markdown
11
11
  License-File: LICENSE
12
- Requires-Dist: jax==0.5.3
13
- Requires-Dist: jaxdecomp==0.2.7
12
+ Requires-Dist: jax==0.6.0
13
+ Requires-Dist: jaxdecomp==0.2.9
14
14
  Requires-Dist: tensorflow
15
- Requires-Dist: orbax-checkpoint==0.11.18
16
- Requires-Dist: optax==0.2.5
15
+ Requires-Dist: orbax-checkpoint==0.11.25
16
+ Requires-Dist: optax==0.2.6
17
17
  Requires-Dist: numpy
18
18
  Requires-Dist: matplotlib
19
19
  Requires-Dist: setuptools>=70.1.1
20
+ Requires-Dist: chex==0.1.90
20
21
  Provides-Extra: cuda12
21
- Requires-Dist: jax[cuda12]==0.5.3; extra == "cuda12"
22
+ Requires-Dist: jax[cuda12]==0.6.0; extra == "cuda12"
22
23
  Dynamic: license-file
23
24
 
24
25
  <p align="center">
@@ -42,7 +42,7 @@ class Simulation:
42
42
 
43
43
  """
44
44
 
45
- def __init__(self, params, sharding=None):
45
+ def __init__(self, params, sharding=None, checkpoint_number=None):
46
46
  # allow loading directly from a checkpoint path
47
47
  load_from_checkpoint = False
48
48
  checkpoint_dir = ""
@@ -105,6 +105,7 @@ class Simulation:
105
105
  self.custom_drift = None
106
106
  self.custom_density = None
107
107
  self.custom_plot = None
108
+ self.callback = None
108
109
 
109
110
  # simulation state
110
111
  self.state = {}
@@ -142,7 +143,10 @@ class Simulation:
142
143
  async_checkpoint_manager = ocp.CheckpointManager(
143
144
  checkpoint_dir, options=options
144
145
  )
145
- step = async_checkpoint_manager.latest_step()
146
+ if checkpoint_number is not None:
147
+ step = checkpoint_number
148
+ else:
149
+ step = async_checkpoint_manager.latest_step()
146
150
  self.state = async_checkpoint_manager.restore(
147
151
  step, args=ocp.args.StandardRestore(self.state)
148
152
  )
@@ -178,6 +182,23 @@ class Simulation:
178
182
  """
179
183
  return self.box_size / self.resolution
180
184
 
185
+ @property
186
+ def nt(self):
187
+ """
188
+ Return the number of timesteps
189
+ """
190
+ dx = self.dx
191
+ m_per_hbar = self.m_per_hbar
192
+ safety = self.params["time"]["safety_factor"]
193
+ dt_kin = safety * (m_per_hbar / 6.0) * (dx * dx)
194
+ t_start = self.params["time"]["start"]
195
+ t_end = self.params["time"]["end"]
196
+ t_span = t_end - t_start
197
+ num_checkpoints = self.params["output"]["num_checkpoints"]
198
+ nt = int(round(round(t_span / dt_kin) / num_checkpoints) * num_checkpoints)
199
+
200
+ return nt
201
+
181
202
  @property
182
203
  def axion_mass(self):
183
204
  """
@@ -341,8 +362,6 @@ class Simulation:
341
362
  num_cells = self.resolution**3
342
363
  m_per_hbar = self.m_per_hbar
343
364
 
344
- safety = self.params["time"]["safety_factor"]
345
- dt_kin = safety * (m_per_hbar / 6.0) * (dx * dx)
346
365
  t_start = self.params["time"]["start"]
347
366
  t_end = self.params["time"]["end"]
348
367
  t_span = t_end - t_start
@@ -360,6 +379,9 @@ class Simulation:
360
379
  if use_custom:
361
380
  custom_kick = self.custom_kick
362
381
  custom_drift = self.custom_drift
382
+ use_callback = self.callback is not None
383
+ if use_callback:
384
+ custom_callback = self.callback
363
385
 
364
386
  # cosmology
365
387
  if use_cosmology:
@@ -390,7 +412,7 @@ class Simulation:
390
412
 
391
413
  # round up to the nearest multiple of num_checkpoints
392
414
  num_checkpoints = self.params["output"]["num_checkpoints"]
393
- nt = int(round(round(t_span / dt_kin) / num_checkpoints) * num_checkpoints)
415
+ nt = self.nt
394
416
  nt_sub = int(round(nt / num_checkpoints))
395
417
  dt = t_span / nt
396
418
 
@@ -466,7 +488,7 @@ class Simulation:
466
488
 
467
489
  return state
468
490
 
469
- def _update(_, carry):
491
+ def _update(i, carry):
470
492
  # Update the simulation state by one timestep
471
493
  # according to a 2nd-order `kick-drift-kick` scheme
472
494
  state, kx, ky, kz, k_sq = carry
@@ -481,6 +503,9 @@ class Simulation:
481
503
  state["redshift"] = 1.0 / scale_factor - 1.0
482
504
  state = _kick(state, kx, ky, kz, k_sq, 0.5 * dt)
483
505
 
506
+ if use_callback:
507
+ custom_callback(i, state)
508
+
484
509
  return state, kx, ky, kz, k_sq
485
510
 
486
511
  # save initial state
@@ -490,16 +515,19 @@ class Simulation:
490
515
  with open(os.path.join(checkpoint_dir, "params.json"), "w") as f:
491
516
  json.dump(self.params, f, indent=2)
492
517
  async_checkpoint_manager.save(0, args=ocp.args.StandardSave(state))
518
+ async_checkpoint_manager.wait_until_finished()
493
519
  plot_sim(state, checkpoint_dir, 0, self.params)
494
520
  if self.custom_plot is not None:
495
521
  self.custom_plot(state, checkpoint_dir, 0, self.params)
496
- async_checkpoint_manager.wait_until_finished()
497
522
 
498
523
  # Simulation Main Loop
499
524
  t_start_timer = time.time()
500
525
  if save:
501
526
  for i in range(1, num_checkpoints + 1):
502
- carry = jax.lax.fori_loop(0, nt_sub, _update, init_val=carry)
527
+ i_timestep = (i - 1) * nt_sub
528
+ carry = jax.lax.fori_loop(
529
+ i_timestep, i_timestep + nt_sub, _update, init_val=carry
530
+ )
503
531
  state, _, _, _, _ = carry
504
532
  jax.block_until_ready(state)
505
533
  # save state
@@ -513,10 +541,10 @@ class Simulation:
513
541
  print(
514
542
  f"{percent:.1f}%: mcups={mcups:.1f}, estimated time left (s): {est_remaining:.1f}"
515
543
  )
544
+ async_checkpoint_manager.wait_until_finished()
516
545
  plot_sim(state, checkpoint_dir, i, self.params)
517
546
  if self.custom_plot is not None:
518
547
  self.custom_plot(state, checkpoint_dir, i, self.params)
519
- async_checkpoint_manager.wait_until_finished()
520
548
  else:
521
549
  carry = jax.lax.fori_loop(0, nt, _update, init_val=carry)
522
550
  state, _, _, _, _ = carry
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: jaxion
3
- Version: 0.0.7
3
+ Version: 0.0.9
4
4
  Summary: A differentiable simulation library for fuzzy dark matter in JAX
5
5
  Author-email: Philip Mocz <philip.mocz@gmail.com>
6
6
  License-Expression: Apache-2.0
@@ -9,16 +9,17 @@ Project-URL: Homepage, https://github.com/JaxionProject/jaxion
9
9
  Requires-Python: >=3.11
10
10
  Description-Content-Type: text/markdown
11
11
  License-File: LICENSE
12
- Requires-Dist: jax==0.5.3
13
- Requires-Dist: jaxdecomp==0.2.7
12
+ Requires-Dist: jax==0.6.0
13
+ Requires-Dist: jaxdecomp==0.2.9
14
14
  Requires-Dist: tensorflow
15
- Requires-Dist: orbax-checkpoint==0.11.18
16
- Requires-Dist: optax==0.2.5
15
+ Requires-Dist: orbax-checkpoint==0.11.25
16
+ Requires-Dist: optax==0.2.6
17
17
  Requires-Dist: numpy
18
18
  Requires-Dist: matplotlib
19
19
  Requires-Dist: setuptools>=70.1.1
20
+ Requires-Dist: chex==0.1.90
20
21
  Provides-Extra: cuda12
21
- Requires-Dist: jax[cuda12]==0.5.3; extra == "cuda12"
22
+ Requires-Dist: jax[cuda12]==0.6.0; extra == "cuda12"
22
23
  Dynamic: license-file
23
24
 
24
25
  <p align="center">
@@ -0,0 +1,12 @@
1
+ jax==0.6.0
2
+ jaxdecomp==0.2.9
3
+ tensorflow
4
+ orbax-checkpoint==0.11.25
5
+ optax==0.2.6
6
+ numpy
7
+ matplotlib
8
+ setuptools>=70.1.1
9
+ chex==0.1.90
10
+
11
+ [cuda12]
12
+ jax[cuda12]==0.6.0
@@ -25,7 +25,7 @@ dependencies = {file = ["requirements.txt"]}
25
25
 
26
26
  [project.optional-dependencies]
27
27
  cuda12 = [
28
- "jax[cuda12]==0.5.3",
28
+ "jax[cuda12]==0.6.0",
29
29
  ]
30
30
 
31
31
  [tool.setuptools-git-versioning]
@@ -0,0 +1,9 @@
1
+ jax==0.6.0
2
+ jaxdecomp==0.2.9
3
+ tensorflow
4
+ orbax-checkpoint==0.11.25
5
+ optax==0.2.6
6
+ numpy
7
+ matplotlib
8
+ setuptools>=70.1.1
9
+ chex==0.1.90
@@ -1,11 +0,0 @@
1
- jax==0.5.3
2
- jaxdecomp==0.2.7
3
- tensorflow
4
- orbax-checkpoint==0.11.18
5
- optax==0.2.5
6
- numpy
7
- matplotlib
8
- setuptools>=70.1.1
9
-
10
- [cuda12]
11
- jax[cuda12]==0.5.3
@@ -1,8 +0,0 @@
1
- jax==0.5.3
2
- jaxdecomp==0.2.7
3
- tensorflow
4
- orbax-checkpoint==0.11.18
5
- optax==0.2.5
6
- numpy
7
- matplotlib
8
- setuptools>=70.1.1
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes