jaxion 0.0.7__tar.gz → 0.0.8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {jaxion-0.0.7 → jaxion-0.0.8}/PKG-INFO +1 -1
- {jaxion-0.0.7 → jaxion-0.0.8}/jaxion/simulation.py +37 -9
- {jaxion-0.0.7 → jaxion-0.0.8}/jaxion.egg-info/PKG-INFO +1 -1
- {jaxion-0.0.7 → jaxion-0.0.8}/LICENSE +0 -0
- {jaxion-0.0.7 → jaxion-0.0.8}/README.md +0 -0
- {jaxion-0.0.7 → jaxion-0.0.8}/jaxion/__init__.py +0 -0
- {jaxion-0.0.7 → jaxion-0.0.8}/jaxion/analysis.py +0 -0
- {jaxion-0.0.7 → jaxion-0.0.8}/jaxion/constants.py +0 -0
- {jaxion-0.0.7 → jaxion-0.0.8}/jaxion/cosmology.py +0 -0
- {jaxion-0.0.7 → jaxion-0.0.8}/jaxion/defaults.json +0 -0
- {jaxion-0.0.7 → jaxion-0.0.8}/jaxion/gravity.py +0 -0
- {jaxion-0.0.7 → jaxion-0.0.8}/jaxion/hydro.py +0 -0
- {jaxion-0.0.7 → jaxion-0.0.8}/jaxion/particles.py +0 -0
- {jaxion-0.0.7 → jaxion-0.0.8}/jaxion/quantum.py +0 -0
- {jaxion-0.0.7 → jaxion-0.0.8}/jaxion/utils.py +0 -0
- {jaxion-0.0.7 → jaxion-0.0.8}/jaxion/visualization.py +0 -0
- {jaxion-0.0.7 → jaxion-0.0.8}/jaxion.egg-info/SOURCES.txt +0 -0
- {jaxion-0.0.7 → jaxion-0.0.8}/jaxion.egg-info/dependency_links.txt +0 -0
- {jaxion-0.0.7 → jaxion-0.0.8}/jaxion.egg-info/requires.txt +0 -0
- {jaxion-0.0.7 → jaxion-0.0.8}/jaxion.egg-info/top_level.txt +0 -0
- {jaxion-0.0.7 → jaxion-0.0.8}/pyproject.toml +0 -0
- {jaxion-0.0.7 → jaxion-0.0.8}/requirements.txt +0 -0
- {jaxion-0.0.7 → jaxion-0.0.8}/setup.cfg +0 -0
- {jaxion-0.0.7 → jaxion-0.0.8}/tests/test_analysis.py +0 -0
- {jaxion-0.0.7 → jaxion-0.0.8}/tests/test_cosmology.py +0 -0
- {jaxion-0.0.7 → jaxion-0.0.8}/tests/test_examples.py +0 -0
|
@@ -42,7 +42,7 @@ class Simulation:
|
|
|
42
42
|
|
|
43
43
|
"""
|
|
44
44
|
|
|
45
|
-
def __init__(self, params, sharding=None):
|
|
45
|
+
def __init__(self, params, sharding=None, checkpoint_number=None):
|
|
46
46
|
# allow loading directly from a checkpoint path
|
|
47
47
|
load_from_checkpoint = False
|
|
48
48
|
checkpoint_dir = ""
|
|
@@ -105,6 +105,7 @@ class Simulation:
|
|
|
105
105
|
self.custom_drift = None
|
|
106
106
|
self.custom_density = None
|
|
107
107
|
self.custom_plot = None
|
|
108
|
+
self.callback = None
|
|
108
109
|
|
|
109
110
|
# simulation state
|
|
110
111
|
self.state = {}
|
|
@@ -142,7 +143,10 @@ class Simulation:
|
|
|
142
143
|
async_checkpoint_manager = ocp.CheckpointManager(
|
|
143
144
|
checkpoint_dir, options=options
|
|
144
145
|
)
|
|
145
|
-
|
|
146
|
+
if checkpoint_number is not None:
|
|
147
|
+
step = checkpoint_number
|
|
148
|
+
else:
|
|
149
|
+
step = async_checkpoint_manager.latest_step()
|
|
146
150
|
self.state = async_checkpoint_manager.restore(
|
|
147
151
|
step, args=ocp.args.StandardRestore(self.state)
|
|
148
152
|
)
|
|
@@ -178,6 +182,23 @@ class Simulation:
|
|
|
178
182
|
"""
|
|
179
183
|
return self.box_size / self.resolution
|
|
180
184
|
|
|
185
|
+
@property
|
|
186
|
+
def nt(self):
|
|
187
|
+
"""
|
|
188
|
+
Return the number of timesteps
|
|
189
|
+
"""
|
|
190
|
+
dx = self.dx
|
|
191
|
+
m_per_hbar = self.m_per_hbar
|
|
192
|
+
safety = self.params["time"]["safety_factor"]
|
|
193
|
+
dt_kin = safety * (m_per_hbar / 6.0) * (dx * dx)
|
|
194
|
+
t_start = self.params["time"]["start"]
|
|
195
|
+
t_end = self.params["time"]["end"]
|
|
196
|
+
t_span = t_end - t_start
|
|
197
|
+
num_checkpoints = self.params["output"]["num_checkpoints"]
|
|
198
|
+
nt = int(round(round(t_span / dt_kin) / num_checkpoints) * num_checkpoints)
|
|
199
|
+
|
|
200
|
+
return nt
|
|
201
|
+
|
|
181
202
|
@property
|
|
182
203
|
def axion_mass(self):
|
|
183
204
|
"""
|
|
@@ -341,8 +362,6 @@ class Simulation:
|
|
|
341
362
|
num_cells = self.resolution**3
|
|
342
363
|
m_per_hbar = self.m_per_hbar
|
|
343
364
|
|
|
344
|
-
safety = self.params["time"]["safety_factor"]
|
|
345
|
-
dt_kin = safety * (m_per_hbar / 6.0) * (dx * dx)
|
|
346
365
|
t_start = self.params["time"]["start"]
|
|
347
366
|
t_end = self.params["time"]["end"]
|
|
348
367
|
t_span = t_end - t_start
|
|
@@ -360,6 +379,9 @@ class Simulation:
|
|
|
360
379
|
if use_custom:
|
|
361
380
|
custom_kick = self.custom_kick
|
|
362
381
|
custom_drift = self.custom_drift
|
|
382
|
+
use_callback = self.callback is not None
|
|
383
|
+
if use_callback:
|
|
384
|
+
custom_callback = self.callback
|
|
363
385
|
|
|
364
386
|
# cosmology
|
|
365
387
|
if use_cosmology:
|
|
@@ -390,7 +412,7 @@ class Simulation:
|
|
|
390
412
|
|
|
391
413
|
# round up to the nearest multiple of num_checkpoints
|
|
392
414
|
num_checkpoints = self.params["output"]["num_checkpoints"]
|
|
393
|
-
nt =
|
|
415
|
+
nt = self.nt
|
|
394
416
|
nt_sub = int(round(nt / num_checkpoints))
|
|
395
417
|
dt = t_span / nt
|
|
396
418
|
|
|
@@ -466,7 +488,7 @@ class Simulation:
|
|
|
466
488
|
|
|
467
489
|
return state
|
|
468
490
|
|
|
469
|
-
def _update(
|
|
491
|
+
def _update(i, carry):
|
|
470
492
|
# Update the simulation state by one timestep
|
|
471
493
|
# according to a 2nd-order `kick-drift-kick` scheme
|
|
472
494
|
state, kx, ky, kz, k_sq = carry
|
|
@@ -481,6 +503,9 @@ class Simulation:
|
|
|
481
503
|
state["redshift"] = 1.0 / scale_factor - 1.0
|
|
482
504
|
state = _kick(state, kx, ky, kz, k_sq, 0.5 * dt)
|
|
483
505
|
|
|
506
|
+
if use_callback:
|
|
507
|
+
custom_callback(i, state)
|
|
508
|
+
|
|
484
509
|
return state, kx, ky, kz, k_sq
|
|
485
510
|
|
|
486
511
|
# save initial state
|
|
@@ -490,16 +515,19 @@ class Simulation:
|
|
|
490
515
|
with open(os.path.join(checkpoint_dir, "params.json"), "w") as f:
|
|
491
516
|
json.dump(self.params, f, indent=2)
|
|
492
517
|
async_checkpoint_manager.save(0, args=ocp.args.StandardSave(state))
|
|
518
|
+
async_checkpoint_manager.wait_until_finished()
|
|
493
519
|
plot_sim(state, checkpoint_dir, 0, self.params)
|
|
494
520
|
if self.custom_plot is not None:
|
|
495
521
|
self.custom_plot(state, checkpoint_dir, 0, self.params)
|
|
496
|
-
async_checkpoint_manager.wait_until_finished()
|
|
497
522
|
|
|
498
523
|
# Simulation Main Loop
|
|
499
524
|
t_start_timer = time.time()
|
|
500
525
|
if save:
|
|
501
526
|
for i in range(1, num_checkpoints + 1):
|
|
502
|
-
|
|
527
|
+
i_timestep = (i - 1) * nt_sub
|
|
528
|
+
carry = jax.lax.fori_loop(
|
|
529
|
+
i_timestep, i_timestep + nt_sub, _update, init_val=carry
|
|
530
|
+
)
|
|
503
531
|
state, _, _, _, _ = carry
|
|
504
532
|
jax.block_until_ready(state)
|
|
505
533
|
# save state
|
|
@@ -513,10 +541,10 @@ class Simulation:
|
|
|
513
541
|
print(
|
|
514
542
|
f"{percent:.1f}%: mcups={mcups:.1f}, estimated time left (s): {est_remaining:.1f}"
|
|
515
543
|
)
|
|
544
|
+
async_checkpoint_manager.wait_until_finished()
|
|
516
545
|
plot_sim(state, checkpoint_dir, i, self.params)
|
|
517
546
|
if self.custom_plot is not None:
|
|
518
547
|
self.custom_plot(state, checkpoint_dir, i, self.params)
|
|
519
|
-
async_checkpoint_manager.wait_until_finished()
|
|
520
548
|
else:
|
|
521
549
|
carry = jax.lax.fori_loop(0, nt, _update, init_val=carry)
|
|
522
550
|
state, _, _, _, _ = carry
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|