jaxion 0.0.2__py3-none-any.whl → 0.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxion/__init__.py CHANGED
@@ -1,2 +1,3 @@
1
1
  from .simulation import Simulation as Simulation
2
2
  from .constants import constants as constants
3
+ from .analysis import radial_power_spectrum as radial_power_spectrum
jaxion/analysis.py ADDED
@@ -0,0 +1,58 @@
1
+ import jax.numpy as jnp
2
+ import jaxdecomp as jd
3
+
4
+
5
def radial_power_spectrum(data_cube, kx, ky, kz, box_size):
    """
    Computes the radially averaged power spectral density of a 3D datacube.

    Parameters
    ----------
    data_cube : jnp.ndarray
        3D data cube, must be cubic
    kx, ky, kz: jnp.ndarray
        physical wavenumber grids (multiples of 2*pi/box_size) in each
        dimension, laid out like the output of ``jd.fft.pfft3d(data_cube)``
        (jaxdecomp transposes the axes, e.g. the grids returned by
        ``Simulation.kgrid``)
    box_size: float
        physical size of box

    Returns
    -------
    Pf: jnp.ndarray
        radial power spectrum, one value per spherical shell
        (length nx // 2 + 1)
    k: jnp.ndarray
        wavenumbers of the shell centres, k[i] = i * 2*pi / box_size
    total_power: float
        total power (sum of the per-mode power over all modes)
    """
    dim = data_cube.ndim  # pfft3d is a 3D transform, so this is 3 in practice
    nx = data_cube.shape[0]
    dx = box_size / nx
    # fundamental wavenumber of the periodic box = spacing of the k grid
    dk = 2.0 * jnp.pi / box_size

    # power carried by each Fourier mode
    data_cube_hat = jd.fft.pfft3d(data_cube)
    phi_k = 0.5 * jnp.abs(data_cube_hat) ** 2 / nx**dim * dx**dim
    total_power = jnp.sum(phi_k)
    half_size = nx // 2 + 1

    # |k| expressed in units of the fundamental mode, so that the shell centred
    # on integer i matches the returned wavenumber k[i] = i * dk.  (Binning the
    # physical |k| directly into integer-centred bins mislabels the shells
    # whenever box_size != 2*pi.)
    k_r = jnp.sqrt(kx**2 + ky**2 + kz**2) / dk
    shell_range = (-0.5, half_size - 0.5)

    Pf, _ = jnp.histogram(k_r, range=shell_range, bins=half_size, weights=phi_k)
    norm, _ = jnp.histogram(k_r, range=shell_range, bins=half_size)
    # mean power per mode in each shell; empty shells stay at zero
    Pf /= norm + (norm == 0)

    k = dk * jnp.arange(half_size)

    # convert to a density per unit k-space volume, then multiply by the
    # spherical shell surface area 4*pi*k^2
    Pf /= dk**dim
    Pf *= 4.0 * jnp.pi * k**2

    return Pf, k, total_power
jaxion/constants.py CHANGED
@@ -1,13 +1,13 @@
1
1
  # Copyright (c) 2025 The Jaxion Team.
2
2
 
3
- # Physical Constants
4
- #
5
- # Jaxion uses a unit system of:
6
- # [L] = kpc
7
- # [V] = km/s
8
- # [M] = Msun
9
- #
10
- # other units are derived from these base units, e.g., [T] = [L]/[V] = kpc / (km/s) ~= 0.978 Gyr
3
+ """
4
+ Physical constants in units of:
5
+ [L] = kpc,
6
+ [V] = km/s,
7
+ [M] = Msun
8
+
9
+ note: other units are derived from these base units, e.g., [T] = [L]/[V] = kpc / (km/s) ~= 0.978 Gyr
10
+ """
11
11
 
12
12
  constants = {
13
13
  "gravitational_constant": 4.30241002e-6, # G / (kpc * (km/s)^2 / Msun)
jaxion/cosmology.py ADDED
@@ -0,0 +1,74 @@
1
+ import jax.numpy as jnp
2
+
3
+ # Pure functions for cosmology simulation
4
+
5
+
6
def get_physical_time_interval(z_start, z_end, omega_matter, omega_lambda, little_h):
    """
    Compute the total physical time between two redshifts.

    Integrates dt = da / (da/dt) over the scale factor with the trapezoid rule
    on 10000 uniformly spaced samples, where
    da/dt = H0 * sqrt(omega_matter / a + omega_lambda * a^2) and
    H0 = 0.1 * little_h (km/s/kpc).  The result is in code time units
    kpc / (km/s); it is negative when z_end > z_start.
    """
    hubble_0 = 0.1 * little_h  # Hubble constant in (km/s/kpc)
    scale_factors = jnp.linspace(1.0 / (1.0 + z_start), 1.0 / (1.0 + z_end), 10000)
    expansion_rate = hubble_0 * jnp.sqrt(
        (omega_matter / scale_factors) + (omega_lambda * scale_factors**2)
    )
    return jnp.trapezoid(1.0 / expansion_rate, scale_factors)
17
+
18
+
19
def get_supercomoving_time_interval(
    z_start, z_end, omega_matter, omega_lambda, little_h
):
    """
    Compute the total supercomoving time (dt_hat = a^-2 dt) between two redshifts.

    With da/dt = H0 * sqrt(omega_matter / a + omega_lambda * a^2) and
    H0 = 0.1 * little_h (km/s/kpc), the integrand is
    dt_hat/da = a^-2 / (da/dt).  It is integrated with the trapezoid rule on
    10000 uniformly spaced scale-factor samples between the two redshifts.
    """
    hubble_0 = 0.1 * little_h  # Hubble constant in (km/s/kpc)
    a_grid = jnp.linspace(1.0 / (1.0 + z_start), 1.0 / (1.0 + z_end), 10000)
    da_dt = hubble_0 * jnp.sqrt((omega_matter / a_grid) + (omega_lambda * a_grid**2))
    integrand = a_grid**-2 / da_dt
    return jnp.trapezoid(integrand, a_grid)
35
+
36
+
37
def get_scale_factor(z_start, dt_hat, omega_matter, omega_lambda, little_h):
    """
    Compute the scale factor reached after a supercomoving time dt_hat,
    starting from redshift z_start, by bisection.

    The search starts at a = 1 / (1 + z_start) with the bracket [0, 2]. It stops
    once the supercomoving interval to the trial scale factor matches dt_hat
    to within 1e-6, or after 100 iterations.

    NOTE: uses Python control flow on array values, so it must run eagerly
    (not under jax.jit).  If the target lies beyond the bracket, the result
    saturates at the bracket edge without raising.
    """
    tolerance = 1e-6
    lo, hi = 0.0, 2.0  # bracket for the scale factor; 2.0 is a generous upper limit
    a = 1.0 / (1.0 + z_start)
    for _ in range(100):
        mismatch = (
            get_supercomoving_time_interval(
                z_start, 1.0 / a - 1.0, omega_matter, omega_lambda, little_h
            )
            - dt_hat
        )
        if jnp.abs(mismatch) < tolerance:
            break
        # bisection: overshooting the target time means a is too large
        if mismatch > 0:
            hi = a
        else:
            lo = a
        a = 0.5 * (lo + hi)
    return a
60
+
61
+
62
def get_next_scale_factor(z_start, dt_hat, omega_matter, omega_lambda, little_h):
    """
    Advance the scale factor from redshift z_start by one supercomoving time
    step dt_hat.

    Integrates da/dt_hat = a^2 * da/dt, with
    da/dt = H0 * sqrt(omega_matter / a + omega_lambda * a^2) and
    H0 = 0.1 * little_h (km/s/kpc), using a single explicit midpoint (RK2)
    step.  Returns the new scale factor a (redshift is 1/a - 1).
    """
    hubble_0 = 0.1 * little_h

    def da_dt_hat(a):
        # a^2 converts the physical-time rate da/dt to supercomoving time
        return a**2 * (
            hubble_0 * jnp.sqrt((omega_matter / a) + (omega_lambda * a**2))
        )

    a0 = 1.0 / (1.0 + z_start)
    slope_start = da_dt_hat(a0)
    slope_mid = da_dt_hat(a0 + 0.5 * dt_hat * slope_start)
    return a0 + dt_hat * slope_mid
jaxion/gravity.py CHANGED
@@ -1,9 +1,10 @@
1
1
  import jax.numpy as jnp
2
+ import jaxdecomp as jd
2
3
 
3
4
  # Pure functions for gravity calculations
4
5
 
5
6
 
6
7
  def calculate_gravitational_potential(rho, k_sq, G, rho_bar):
7
- Vhat = -jnp.fft.fftn(4.0 * jnp.pi * G * (rho - rho_bar)) / (k_sq + (k_sq == 0))
8
- V = jnp.real(jnp.fft.ifftn(Vhat))
8
+ Vhat = -jd.fft.pfft3d(4.0 * jnp.pi * G * (rho - rho_bar)) / (k_sq + (k_sq == 0))
9
+ V = jnp.real(jd.fft.pifft3d(Vhat))
9
10
  return V
jaxion/hydro.py CHANGED
@@ -1,4 +1,5 @@
1
1
  import jax.numpy as jnp
2
+ import jaxdecomp as jd
2
3
 
3
4
  # Pure functions for hydro simulation
4
5
 
@@ -89,11 +90,11 @@ def get_flux(rho_L, vx_L, vy_L, vz_L, rho_R, vx_R, vy_R, vz_R, cs):
89
90
 
90
91
 
91
92
  def hydro_accelerate(vx, vy, vz, V, kx, ky, kz, dt):
92
- V_hat = jnp.fft.fftn(V)
93
+ V_hat = jd.fft.pfft3d(V)
93
94
 
94
- ax = -jnp.real(jnp.fft.ifftn(1.0j * kx * V_hat))
95
- ay = -jnp.real(jnp.fft.ifftn(1.0j * ky * V_hat))
96
- az = -jnp.real(jnp.fft.ifftn(1.0j * kz * V_hat))
95
+ ax = -jnp.real(jd.fft.pifft3d(1.0j * kx * V_hat))
96
+ ay = -jnp.real(jd.fft.pifft3d(1.0j * ky * V_hat))
97
+ az = -jnp.real(jd.fft.pifft3d(1.0j * kz * V_hat))
97
98
 
98
99
  vx += ax * dt
99
100
  vy += ay * dt
@@ -28,7 +28,7 @@
28
28
  "domain": {
29
29
  "box_size": {
30
30
  "default": 10.0,
31
- "description": "periodic domain box size (kpc)."
31
+ "description": "periodic domain box size [kpc]."
32
32
  },
33
33
  "resolution_base": {
34
34
  "default": 32,
@@ -40,38 +40,75 @@
40
40
  }
41
41
  },
42
42
  "time": {
43
- "start": 0.0,
44
- "end": 1.0,
45
- "adaptive": false
43
+ "start": {
44
+ "default": 0.0,
45
+ "description": "simulation start time [kpc/(km/s)] or [redshift] (cosmology=true)."
46
+ },
47
+ "end": {
48
+ "default": 1.0,
49
+ "description": "simulation end time [kpc/(km/s)] or [redshift] (cosmology=true)."
50
+ },
51
+ "adaptive": {
52
+ "default": false,
53
+ "description": "switch on for adaptive time stepping."
54
+ }
46
55
  },
47
56
  "output": {
48
- "path": "./checkpoints",
49
- "num_checkpoints": 100,
50
- "save": true,
51
- "plot_dynamic_range": 100.0
57
+ "path": {
58
+ "default": "./checkpoints",
59
+ "description": "path to output directory."
60
+ },
61
+ "num_checkpoints": {
62
+ "default": 100,
63
+ "description": "number of checkpoints to save."
64
+ },
65
+ "save": {
66
+ "default": true,
67
+ "description": "switch on to save checkpoints."
68
+ },
69
+ "plot_dynamic_range": {
70
+ "default": 100.0,
71
+ "description": "dynamic range for plotting."
72
+ }
52
73
  },
53
74
  "quantum": {
54
75
  "m_22": {
55
76
  "default": 1.0,
56
- "description": "axion mass in units of 10^{-22} eV."
77
+ "description": "axion mass [10^{-22} eV]."
57
78
  }
58
79
  },
59
80
  "hydro": {
60
81
  "sound_speed": {
61
82
  "default": 1.0,
62
- "description": "Isothermal sound speed (km/s)."
83
+ "description": "isothermal sound speed [km/s]."
63
84
  }
64
85
  },
65
86
  "particles": {
66
- "num_particles": 0,
67
- "particle_mass": 1.0
87
+ "num_particles": {
88
+ "default": 0,
89
+ "description": "number of particles."
90
+ },
91
+ "particle_mass": {
92
+ "default": 1.0,
93
+ "description": "particle mass [M_sun]."
94
+ }
68
95
  },
69
96
  "cosmology": {
70
- "Omega_m": 0.3,
71
- "Omega_Lambda": 0.7
97
+ "omega_matter": {
98
+ "default": 0.3,
99
+ "description": "matter density parameter."
100
+ },
101
+ "omega_lambda": {
102
+ "default": 0.7,
103
+ "description": "dark energy density parameter."
104
+ },
105
+ "little_h": {
106
+ "default": 0.7,
107
+ "description": "Hubble parameter little h (H_0 = 100 h km/s/Mpc)."
108
+ }
72
109
  },
73
110
  "version": {
74
111
  "default": "unknown",
75
- "description": "version of the jaxion library used to create these parameters (automatically detected)."
112
+ "description": "jaxion version used (auto detected)."
76
113
  }
77
114
  }
jaxion/particles.py CHANGED
@@ -1,5 +1,6 @@
1
1
  import jax
2
2
  import jax.numpy as jnp
3
+ import jaxdecomp as jd
3
4
 
4
5
  # Pure functions for particle-mesh calculations (stars, BHs, ...)
5
6
 
@@ -64,10 +65,10 @@ def get_acceleration(pos, V, kx, ky, kz, dx):
64
65
  i, ip1, w_i, w_ip1 = get_cic_indices_and_weights(pos, dx, resolution)
65
66
 
66
67
  # find accelerations on the grid
67
- V_hat = jnp.fft.fftn(V)
68
- ax = -jnp.real(jnp.fft.ifftn(1.0j * kx * V_hat))
69
- ay = -jnp.real(jnp.fft.ifftn(1.0j * ky * V_hat))
70
- az = -jnp.real(jnp.fft.ifftn(1.0j * kz * V_hat))
68
+ V_hat = jd.fft.pfft3d(V)
69
+ ax = -jnp.real(jd.fft.pifft3d(1.0j * kx * V_hat))
70
+ ay = -jnp.real(jd.fft.pifft3d(1.0j * ky * V_hat))
71
+ az = -jnp.real(jd.fft.pifft3d(1.0j * kz * V_hat))
71
72
  a_grid = jnp.stack((ax, ay, az), axis=-1)
72
73
 
73
74
  # interpolate the accelerations to the particle positions
jaxion/quantum.py CHANGED
@@ -1,4 +1,5 @@
1
1
  import jax.numpy as jnp
2
+ import jaxdecomp as jd
2
3
 
3
4
  # Pure functions for quantum simulation
4
5
 
@@ -9,7 +10,86 @@ def quantum_kick(psi, V, m_per_hbar, dt):
9
10
 
10
11
 
11
12
  def quantum_drift(psi, k_sq, m_per_hbar, dt):
12
- psi_hat = jnp.fft.fftn(psi)
13
+ psi_hat = jd.fft.pfft3d(psi)
13
14
  psi_hat = jnp.exp(dt * (-1.0j * k_sq / m_per_hbar / 2.0)) * psi_hat
14
- psi = jnp.fft.ifftn(psi_hat)
15
+ psi = jd.fft.pifft3d(psi_hat)
15
16
  return psi
17
+
18
+
19
def get_gradient(psi, kx, ky, kz):
    """
    Computes the gradient of the wavefunction psi spectrally.

    psi: jnp.ndarray (3D)
    kx, ky, kz: wavenumber grids laid out like the output of jd.fft.pfft3d(psi)
    Returns: grad_psi_x, grad_psi_y, grad_psi_z
    """
    psi_hat = jd.fft.pfft3d(psi)
    # a derivative along axis j is multiplication by i*k_j in Fourier space
    grad_psi_x, grad_psi_y, grad_psi_z = (
        jd.fft.pifft3d(1.0j * k_component * psi_hat)
        for k_component in (kx, ky, kz)
    )
    return grad_psi_x, grad_psi_y, grad_psi_z
35
+
36
+
37
def _wrap_phase(delta):
    """
    Map phase differences in (-2*pi, 2*pi) onto the interval (-pi, pi].

    Differences of jnp.angle values can jump by 2*pi across the branch cut;
    adding or subtracting 2*pi once removes that jump.
    """
    delta = jnp.where(delta > jnp.pi, delta - 2 * jnp.pi, delta)
    return jnp.where(delta <= -jnp.pi, delta + 2 * jnp.pi, delta)


def quantum_velocity(psi, box_size, m_per_hbar):
    """
    Compute the velocity from the wave-function:
    v = nabla S / m, where psi = sqrt(rho) exp(i S / hbar).

    The phase S / hbar = angle(psi) is differentiated with periodic forward
    and backward finite differences.  Each difference is phase-unwrapped onto
    (-pi, pi], and the two are then averaged.  This assumes the phase changes
    by less than pi between neighbouring cells.

    Parameters
    ----------
    psi : jnp.ndarray
        3D complex wavefunction; the grid is assumed cubic, since the cell
        size is box_size / psi.shape[0]
    box_size : float
        physical size of the box
    m_per_hbar : float
        particle mass divided by the reduced Planck constant

    Returns
    -------
    vx, vy, vz : jnp.ndarray
        velocity components along axes 0, 1 and 2, same shape as psi
    """
    dx = box_size / psi.shape[0]
    S_per_hbar = jnp.angle(psi)

    def _component(axis):
        # forward: S[i+1] - S[i]; backward: S[i] - S[i-1] (periodic via roll)
        forward = _wrap_phase(jnp.roll(S_per_hbar, -1, axis=axis) - S_per_hbar)
        backward = _wrap_phase(S_per_hbar - jnp.roll(S_per_hbar, 1, axis=axis))
        return 0.5 * (forward / dx / m_per_hbar + backward / dx / m_per_hbar)

    vx, vy, vz = (_component(axis) for axis in range(3))
    return vx, vy, vz
jaxion/simulation.py CHANGED
@@ -6,11 +6,19 @@ import json
6
6
  import time
7
7
 
8
8
  from .constants import constants
9
- from .quantum import quantum_kick, quantum_drift
9
+ from .quantum import quantum_kick, quantum_drift, quantum_velocity
10
10
  from .gravity import calculate_gravitational_potential
11
11
  from .hydro import hydro_fluxes, hydro_accelerate
12
12
  from .particles import particles_accelerate, particles_drift, bin_particles
13
- from .utils import set_up_parameters, print_parameters
13
+ from .cosmology import get_supercomoving_time_interval, get_next_scale_factor
14
+ from .utils import (
15
+ set_up_parameters,
16
+ print_parameters,
17
+ xmeshgrid,
18
+ xmeshgrid_transpose,
19
+ xzeros,
20
+ xones,
21
+ )
14
22
  from .visualization import plot_sim
15
23
 
16
24
 
@@ -20,11 +28,24 @@ class Simulation:
20
28
 
21
29
  Parameters
22
30
  ----------
23
- params (dict): The Python dictionary that contains the simulation parameters.
31
+ params : dict
32
+ The Python dictionary that contains the simulation parameters.
33
+ Params can also be a string path to a checkpoint directory to load a saved simulation.
34
+ sharding : jax.sharding.NamedSharding, optional
35
+ jax sharding used for distributed (multi-GPU) simulations
24
36
 
25
37
  """
26
38
 
27
- def __init__(self, params):
39
+ def __init__(self, params, sharding=None):
40
+ # allow loading directly from a checkpoint path
41
+ load_from_checkpoint = False
42
+ checkpoint_dir = ""
43
+ if isinstance(params, str):
44
+ load_from_checkpoint = True
45
+ checkpoint_dir = os.path.join(os.getcwd(), params)
46
+ with open(os.path.join(checkpoint_dir, "params.json"), "r") as f:
47
+ params = json.load(f)
48
+
28
49
  # start from default simulation parameters and update with user params
29
50
  self._params = set_up_parameters(params)
30
51
 
@@ -32,40 +53,85 @@ class Simulation:
32
53
  if self.resolution % 2 != 0:
33
54
  raise ValueError("Resolution must be divisible by 2.")
34
55
 
56
+ if self.params["time"]["adaptive"]:
57
+ raise NotImplementedError("Adaptive time stepping is not yet implemented.")
58
+
59
+ if self.params["physics"]["cosmology"]:
60
+ if (
61
+ self.params["physics"]["hydro"]
62
+ or self.params["physics"]["particles"]
63
+ or self.params["physics"]["external_potential"]
64
+ ):
65
+ raise NotImplementedError(
66
+ "Cosmological hydro/particles/external_potential physics is not yet implemented."
67
+ )
68
+
69
+ if self.params["physics"]["hydro"] or self.params["physics"]["particles"]:
70
+ if sharding is not None:
71
+ raise NotImplementedError(
72
+ "hydro/particles sharding is not yet implemented."
73
+ )
74
+
35
75
  if self.params["output"]["save"]:
36
- print("Simulation initialized with parameters:")
37
- print_parameters(self.params)
76
+ if jax.process_index() == 0:
77
+ print("Simulation parameters:")
78
+ print_parameters(self.params)
79
+
80
+ # jitted functions
81
+ self.xmeshgrid_jit = jax.jit(
82
+ xmeshgrid, in_shardings=None, out_shardings=sharding
83
+ )
84
+ self.xmeshgrid_transpose_jit = jax.jit(
85
+ xmeshgrid_transpose, in_shardings=None, out_shardings=sharding
86
+ )
87
+ self.xzeros_jit = jax.jit(
88
+ xzeros, static_argnums=0, in_shardings=None, out_shardings=sharding
89
+ )
90
+ self.xones_jit = jax.jit(
91
+ xones, static_argnums=0, in_shardings=None, out_shardings=sharding
92
+ )
38
93
 
39
94
  # simulation state
40
95
  self.state = {}
41
96
  self.state["t"] = 0.0
97
+ if self.params["physics"]["cosmology"]:
98
+ self.state["redshift"] = 0.0
42
99
  if self.params["physics"]["quantum"]:
43
- self.state["psi"] = (
44
- jnp.zeros((self.resolution, self.resolution, self.resolution)) * 1j
45
- )
100
+ self.state["psi"] = self.xzeros_jit(self.resolution) * 1j
46
101
  if self.params["physics"]["external_potential"]:
47
- self.state["V_ext"] = jnp.zeros(
48
- (self.resolution, self.resolution, self.resolution)
49
- )
102
+ self.state["V_ext"] = self.xzeros_jit(self.resolution)
50
103
  if self.params["physics"]["hydro"]:
51
104
  self.state["rho"] = jnp.zeros(
52
- (self.resolution, self.resolution, self.resolution)
105
+ (self.resolution, self.resolution, self.resolution),
53
106
  )
54
107
  self.state["vx"] = jnp.zeros(
55
- (self.resolution, self.resolution, self.resolution)
108
+ (self.resolution, self.resolution, self.resolution),
56
109
  )
57
110
  self.state["vy"] = jnp.zeros(
58
- (self.resolution, self.resolution, self.resolution)
111
+ (self.resolution, self.resolution, self.resolution),
59
112
  )
60
113
  self.state["vz"] = jnp.zeros(
61
- (self.resolution, self.resolution, self.resolution)
114
+ (self.resolution, self.resolution, self.resolution),
62
115
  )
63
116
  if self.params["physics"]["particles"]:
64
117
  self.state["pos"] = jnp.zeros((self.num_particles, 3))
65
118
  self.state["vel"] = jnp.zeros((self.num_particles, 3))
66
119
 
120
+ if load_from_checkpoint:
121
+ options = ocp.CheckpointManagerOptions()
122
+ async_checkpoint_manager = ocp.CheckpointManager(
123
+ checkpoint_dir, options=options
124
+ )
125
+ step = async_checkpoint_manager.latest_step()
126
+ self.state = async_checkpoint_manager.restore(
127
+ step, args=ocp.args.StandardRestore(self.state)
128
+ )
129
+
67
130
  @property
68
131
  def resolution(self):
132
+ """
133
+ Return the (linear) resolution of the simulation
134
+ """
69
135
  return (
70
136
  self.params["domain"]["resolution_base"]
71
137
  * self.params["domain"]["resolution_multiplier"]
@@ -73,18 +139,30 @@ class Simulation:
73
139
 
74
140
  @property
75
141
  def num_particles(self):
142
+ """
143
+ Return the number of particles in the simulation
144
+ """
76
145
  return self.params["particles"]["num_particles"]
77
146
 
78
147
  @property
79
148
  def box_size(self):
149
+ """
150
+ Return the box size of the simulation (kpc)
151
+ """
80
152
  return self.params["domain"]["box_size"]
81
153
 
82
154
  @property
83
155
  def dx(self):
156
+ """
157
+ Return the cell size size of the simulation (kpc)
158
+ """
84
159
  return self.box_size / self.resolution
85
160
 
86
161
  @property
87
162
  def axion_mass(self):
163
+ """
164
+ Return the axion particle mass in the simulation (M_sun)
165
+ """
88
166
  return (
89
167
  self.params["quantum"]["m_22"]
90
168
  * 1.0e-22
@@ -94,29 +172,49 @@ class Simulation:
94
172
 
95
173
  @property
96
174
  def sound_speed(self):
175
+ """
176
+ Return the isothermal gas sound speed in the simulation (km/s)
177
+ """
97
178
  return self.params["hydro"]["sound_speed"]
98
179
 
99
180
  @property
100
181
  def params(self):
182
+ """
183
+ Return the parameters of the simulation
184
+ """
101
185
  return self._params
102
186
 
103
187
  @property
104
188
  def grid(self):
189
+ """
190
+ Return the simulation grid
191
+ """
105
192
  hx = 0.5 * self.dx
106
193
  x_lin = jnp.linspace(hx, self.box_size - hx, self.resolution)
107
- X, Y, Z = jnp.meshgrid(x_lin, x_lin, x_lin, indexing="ij")
108
- return X, Y, Z
194
+ xx, yy, zz = self.xmeshgrid_jit(x_lin)
195
+ return xx, yy, zz
109
196
 
110
197
  @property
111
198
  def kgrid(self):
199
+ """
200
+ Return the simulation spectral grid
201
+ """
112
202
  nx = self.resolution
113
203
  k_lin = (2.0 * jnp.pi / self.box_size) * jnp.arange(-nx / 2, nx / 2)
114
- kx, ky, kz = jnp.meshgrid(k_lin, k_lin, k_lin, indexing="ij")
204
+ kx, ky, kz = self.xmeshgrid_transpose_jit(k_lin)
115
205
  kx = jnp.fft.ifftshift(kx)
116
206
  ky = jnp.fft.ifftshift(ky)
117
207
  kz = jnp.fft.ifftshift(kz)
118
208
  return kx, ky, kz
119
209
 
210
@property
def quantum_velocity(self):
    """
    Return the dark matter velocity field (vx, vy, vz) derived from the
    phase of the wavefunction state["psi"].
    """
    m_per_hbar = self.axion_mass / constants["reduced_planck_constant"]
    psi = self.state["psi"]
    return quantum_velocity(psi, self.box_size, m_per_hbar)
217
+
120
218
  def _calc_rho_bar(self, state):
121
219
  rho_bar = 0.0
122
220
  if self.params["physics"]["quantum"]:
@@ -133,7 +231,7 @@ class Simulation:
133
231
  def _calc_grav_potential(self, state, k_sq):
134
232
  G = constants["gravitational_constant"]
135
233
  m_particle = self.params["particles"]["particle_mass"]
136
- rho_bar = self._calc_rho_bar(self.state)
234
+ rho_bar = self._calc_rho_bar(state)
137
235
  rho_tot = 0.0
138
236
  if self.params["physics"]["quantum"]:
139
237
  rho_tot += jnp.abs(state["psi"]) ** 2
@@ -141,10 +239,17 @@ class Simulation:
141
239
  rho_tot += state["rho"]
142
240
  if self.params["physics"]["particles"]:
143
241
  rho_tot += bin_particles(state["pos"], self.dx, self.resolution, m_particle)
242
+ if self.params["physics"]["cosmology"]:
243
+ scale_factor = 1.0 / (1.0 + state["redshift"])
244
+ rho_bar *= scale_factor
245
+ rho_tot *= scale_factor
144
246
  return calculate_gravitational_potential(rho_tot, k_sq, G, rho_bar)
145
247
 
146
248
  @property
147
249
  def potential(self):
250
+ """
251
+ Return the gravitational potential
252
+ """
148
253
  kx, ky, kz = self.kgrid
149
254
  k_sq = kx**2 + ky**2 + kz**2
150
255
  return self._calc_grav_potential(self.state, k_sq)
@@ -167,19 +272,40 @@ class Simulation:
167
272
  # Simulation parameters
168
273
  dx = self.dx
169
274
  m_per_hbar = self.axion_mass / constants["reduced_planck_constant"]
275
+ box_size = self.box_size
170
276
 
171
277
  dt_fac = 1.0
172
278
  dt_kin = dt_fac * (m_per_hbar / 6.0) * (dx * dx)
279
+ t_start = self.params["time"]["start"]
173
280
  t_end = self.params["time"]["end"]
281
+ t_span = t_end - t_start
282
+ state["t"] = t_start
283
+
284
+ # cosmology
285
+ if self.params["physics"]["cosmology"]:
286
+ z_start = self.params["time"]["start"]
287
+ z_end = self.params["time"]["end"]
288
+ omega_matter = self.params["cosmology"]["omega_matter"]
289
+ omega_lambda = self.params["cosmology"]["omega_lambda"]
290
+ little_h = self.params["cosmology"]["little_h"]
291
+ t_span = get_supercomoving_time_interval(
292
+ z_start, z_end, omega_matter, omega_lambda, little_h
293
+ )
294
+ state["t"] = 0.0
295
+ state["redshift"] = z_start
174
296
 
297
+ # hydro
175
298
  c_sound = self.params["hydro"]["sound_speed"]
176
- box_size = self.box_size
177
299
 
178
300
  # round up to the nearest multiple of num_checkpoints
179
301
  num_checkpoints = self.params["output"]["num_checkpoints"]
180
- nt = int(round(round(t_end / dt_kin) / num_checkpoints) * num_checkpoints)
302
+ nt = int(round(round(t_span / dt_kin) / num_checkpoints) * num_checkpoints)
181
303
  nt_sub = int(round(nt / num_checkpoints))
182
- dt = t_end / nt
304
+ dt = t_span / nt
305
+
306
+ # distributed arrays (fixed) needed for calculations
307
+ kx, ky, kz = None, None, None
308
+ k_sq = None
183
309
 
184
310
  # Fourier space variables
185
311
  if self.params["physics"]["gravity"] or self.params["physics"]["quantum"]:
@@ -192,10 +318,14 @@ class Simulation:
192
318
  checkpoint_dir = checkpoint_dir = os.path.join(
193
319
  os.getcwd(), self.params["output"]["path"]
194
320
  )
195
- path = ocp.test_utils.erase_and_create_empty(checkpoint_dir)
321
+ path = os.path.join(os.getcwd(), checkpoint_dir)
322
+ if jax.process_index() == 0:
323
+ path = ocp.test_utils.erase_and_create_empty(checkpoint_dir)
196
324
  async_checkpoint_manager = ocp.CheckpointManager(path, options=options)
197
325
 
198
- def _kick(state, dt):
326
+ carry = (state, kx, ky, kz, k_sq)
327
+
328
+ def _kick(state, kx, ky, kz, k_sq, dt):
199
329
  # Kick (half-step)
200
330
  if (
201
331
  self.params["physics"]["gravity"]
@@ -222,7 +352,7 @@ class Simulation:
222
352
  state["vel"], state["pos"], V, kx, ky, kz, dx, dt
223
353
  )
224
354
 
225
- def _drift(state, dt):
355
+ def _drift(state, k_sq, dt):
226
356
  # Drift (full-step)
227
357
  if self.params["physics"]["quantum"]:
228
358
  state["psi"] = quantum_drift(state["psi"], k_sq, m_per_hbar, dt)
@@ -234,20 +364,30 @@ class Simulation:
234
364
  state["pos"] = particles_drift(state["pos"], state["vel"], dt, box_size)
235
365
 
236
366
  @jax.jit
237
- def _update(_, state):
367
+ def _update(_, carry):
238
368
  # Update the simulation state by one timestep
239
369
  # according to a 2nd-order `kick-drift-kick` scheme
240
- _kick(state, 0.5 * dt)
241
- _drift(state, dt)
242
- _kick(state, 0.5 * dt)
243
-
244
- # update time
370
+ state, kx, ky, kz, k_sq = carry
371
+ _kick(state, kx, ky, kz, k_sq, 0.5 * dt)
372
+ _drift(state, k_sq, dt)
373
+ # update time & redshift
245
374
  state["t"] += dt
375
+ if self.params["physics"]["cosmology"]:
376
+ scale_factor = get_next_scale_factor(
377
+ state["redshift"],
378
+ dt,
379
+ self.params["cosmology"]["omega_matter"],
380
+ self.params["cosmology"]["omega_lambda"],
381
+ self.params["cosmology"]["little_h"],
382
+ )
383
+ state["redshift"] = 1.0 / scale_factor - 1.0
384
+ _kick(state, kx, ky, kz, k_sq, 0.5 * dt)
246
385
 
247
- return state
386
+ return state, kx, ky, kz, k_sq
248
387
 
249
388
  # save initial state
250
- print("Starting simulation ...")
389
+ if jax.process_index() == 0:
390
+ print(f"Starting simulation (res={self.resolution}, nt={nt}) ...")
251
391
  if self.params["output"]["save"]:
252
392
  with open(os.path.join(checkpoint_dir, "params.json"), "w") as f:
253
393
  json.dump(self.params, f, indent=2)
@@ -259,7 +399,8 @@ class Simulation:
259
399
  t_start_timer = time.time()
260
400
  if self.params["output"]["save"]:
261
401
  for i in range(1, num_checkpoints + 1):
262
- state = jax.lax.fori_loop(0, nt_sub, _update, init_val=state)
402
+ carry = jax.lax.fori_loop(0, nt_sub, _update, init_val=carry)
403
+ state, kx, ky, kz, k_sq = carry
263
404
  jax.block_until_ready(state)
264
405
  # save state
265
406
  async_checkpoint_manager.save(i, args=ocp.args.StandardSave(state))
@@ -267,18 +408,23 @@ class Simulation:
267
408
  elapsed = time.time() - t_start_timer
268
409
  est_total = elapsed / i * num_checkpoints
269
410
  est_remaining = est_total - elapsed
270
- print(
271
- f"{percent:.1f}%: estimated time remaining (s): {est_remaining:.1f}"
272
- )
411
+ if jax.process_index() == 0:
412
+ print(
413
+ f"{percent:.1f}%: estimated time remaining (s): {est_remaining:.1f}"
414
+ )
273
415
  plot_sim(state, checkpoint_dir, i, self.params)
274
416
  async_checkpoint_manager.wait_until_finished()
275
417
  else:
276
418
  state = jax.lax.fori_loop(0, nt, _update, init_val=state)
277
419
  jax.block_until_ready(state)
278
- print("Simulation Run Time (s): ", time.time() - t_start_timer)
420
+ if jax.process_index() == 0:
421
+ print("Simulation Run Time (s): ", time.time() - t_start_timer)
279
422
 
280
423
  return state
281
424
 
282
425
  def run(self):
426
+ """
427
+ Run the simulation
428
+ """
283
429
  self.state = self._evolve(self.state)
284
430
  jax.block_until_ready(self.state)
jaxion/utils.py CHANGED
@@ -5,6 +5,7 @@ from pathlib import Path
5
5
  import importlib.resources
6
6
  import json
7
7
  from importlib.metadata import version
8
+ import jax.numpy as jnp
8
9
 
9
10
 
10
11
  def print_parameters(params):
@@ -90,3 +91,29 @@ def run_example_main(example_path, argv=None):
90
91
  os.chdir(old_cwd)
91
92
  sys.argv = old_argv
92
93
  return result
94
+
95
+
96
+ # Make a distributed meshgrid function
97
def xmeshgrid(x_lin):
    """
    Return the three 3D coordinate arrays built from the 1D axis x_lin.

    Uses "ij" (matrix) indexing: xx varies along axis 0, yy along axis 1 and
    zz along axis 2.
    """
    grids = jnp.meshgrid(x_lin, x_lin, x_lin, indexing="ij")
    return grids[0], grids[1], grids[2]
100
+
101
+
102
+ # NOTE: jaxdecomp's (jd) pfft3d transposes the axes (X, Y, Z) --> (Y, Z, X), and pifft3d undoes this,
103
+ # so the Fourier-space variables (e.g. kx, ky, kz) all need to be transposed accordingly
104
+
105
+
106
def xmeshgrid_transpose(x_lin):
    """
    Return "ij" meshgrid arrays from x_lin, each transposed with axis order
    (1, 2, 0).

    This reproduces the (X, Y, Z) --> (Y, Z, X) axis layout of jaxdecomp's
    pfft3d output. The returned x-array therefore varies along the last axis,
    the y-array along axis 0 and the z-array along axis 1.
    """
    axis_order = (1, 2, 0)
    return tuple(
        jnp.transpose(grid, axis_order)
        for grid in jnp.meshgrid(x_lin, x_lin, x_lin, indexing="ij")
    )
112
+
113
+
114
def xzeros(nx):
    """Return an (nx, nx, nx) array of zeros in JAX's default float dtype."""
    shape = (nx,) * 3
    return jnp.zeros(shape)
116
+
117
+
118
def xones(nx):
    """Return an (nx, nx, nx) array of ones in JAX's default float dtype."""
    shape = (nx,) * 3
    return jnp.ones(shape)
jaxion/visualization.py CHANGED
@@ -1,3 +1,4 @@
1
+ import jax
1
2
  import jax.numpy as jnp
2
3
  import matplotlib.pyplot as plt
3
4
  import os
@@ -8,67 +9,82 @@ def plot_sim(state, checkpoint_dir, i, params):
8
9
 
9
10
  dynamic_range = params["output"]["plot_dynamic_range"]
10
11
 
12
+ # process distributed data
11
13
  if params["physics"]["quantum"]:
12
- plt.clf()
13
-
14
- # DM projection
15
14
  nx = state["psi"].shape[0]
16
- rho_bar = jnp.mean(jnp.abs(state["psi"]) ** 2)
17
- vmin = jnp.log10(rho_bar / dynamic_range)
18
- vmax = jnp.log10(rho_bar * dynamic_range)
15
+ rho_bar_dm = jnp.mean(jnp.abs(state["psi"]) ** 2)
16
+ rho_proj_dm = jax.experimental.multihost_utils.process_allgather(
17
+ jnp.log10(jnp.mean(jnp.abs(state["psi"]) ** 2, axis=2).T)
18
+ ).reshape(nx, nx)
19
+ if params["physics"]["hydro"]:
20
+ nx = state["rho"].shape[0]
21
+ rho_bar_gas = jnp.mean(state["rho"])
22
+ rho_proj_gas = jax.experimental.multihost_utils.process_allgather(
23
+ jnp.log10(jnp.mean(state["rho"], axis=2).T)
24
+ ).reshape(nx, nx)
19
25
 
20
- rho_proj_dm = jnp.log10(jnp.mean(jnp.abs(state["psi"]) ** 2, axis=2)).T
21
- ax = plt.gca()
22
- ax.imshow(
23
- rho_proj_dm,
24
- cmap="inferno",
25
- origin="lower",
26
- vmin=vmin,
27
- vmax=vmax,
28
- extent=[0, nx, 0, nx],
29
- )
30
- if params["physics"]["particles"]:
31
- # draw particles
32
- box_size = params["domain"]["box_size"]
33
- sx = (state["pos"][:, 0] / box_size) * nx
34
- sy = (state["pos"][:, 1] / box_size) * nx
35
- plt.plot(sx, sy, color="cyan", marker=".", linestyle="None", markersize=2)
36
- ax.set_aspect("equal")
37
- ax.get_xaxis().set_visible(False)
38
- ax.get_yaxis().set_visible(False)
26
+ # create plot on process 0
27
+ if jax.process_index() == 0:
28
+ if params["physics"]["quantum"]:
29
+ plt.clf()
39
30
 
40
- plt.savefig(
41
- os.path.join(checkpoint_dir, f"dm{i:03d}.png"),
42
- bbox_inches="tight",
43
- pad_inches=0,
44
- )
45
- plt.close()
31
+ # DM projection
32
+ nx = state["psi"].shape[0]
33
+ vmin = jnp.log10(rho_bar_dm / dynamic_range)
34
+ vmax = jnp.log10(rho_bar_dm * dynamic_range)
46
35
 
47
- if params["physics"]["hydro"]:
48
- plt.clf()
36
+ ax = plt.gca()
37
+ ax.imshow(
38
+ rho_proj_dm,
39
+ cmap="inferno",
40
+ origin="lower",
41
+ vmin=vmin,
42
+ vmax=vmax,
43
+ extent=[0, nx, 0, nx],
44
+ )
45
+ if params["physics"]["particles"]:
46
+ # draw particles
47
+ box_size = params["domain"]["box_size"]
48
+ sx = (state["pos"][:, 0] / box_size) * nx
49
+ sy = (state["pos"][:, 1] / box_size) * nx
50
+ plt.plot(
51
+ sx, sy, color="cyan", marker=".", linestyle="None", markersize=5
52
+ )
53
+ ax.set_aspect("equal")
54
+ ax.get_xaxis().set_visible(False)
55
+ ax.get_yaxis().set_visible(False)
49
56
 
50
- # gas projection
51
- nx = state["rho"].shape[0]
52
- rho_bar = jnp.mean(state["rho"])
53
- vmin = jnp.log10(rho_bar / dynamic_range)
54
- vmax = jnp.log10(rho_bar * dynamic_range)
55
- rho_proj_gas = jnp.log10(jnp.mean(state["rho"], axis=2)).T
56
- ax = plt.gca()
57
- ax.imshow(
58
- rho_proj_gas,
59
- cmap="viridis",
60
- origin="lower",
61
- vmin=vmin,
62
- vmax=vmax,
63
- extent=[0, nx, 0, nx],
64
- )
65
- ax.set_aspect("equal")
66
- ax.get_xaxis().set_visible(False)
67
- ax.get_yaxis().set_visible(False)
57
+ plt.savefig(
58
+ os.path.join(checkpoint_dir, f"dm{i:03d}.png"),
59
+ bbox_inches="tight",
60
+ pad_inches=0,
61
+ )
62
+ plt.close()
63
+
64
+ if params["physics"]["hydro"]:
65
+ plt.clf()
66
+
67
+ # gas projection
68
+ nx = state["rho"].shape[0]
69
+ vmin = jnp.log10(rho_bar_gas / dynamic_range)
70
+ vmax = jnp.log10(rho_bar_gas * dynamic_range)
71
+
72
+ ax = plt.gca()
73
+ ax.imshow(
74
+ rho_proj_gas,
75
+ cmap="viridis",
76
+ origin="lower",
77
+ vmin=vmin,
78
+ vmax=vmax,
79
+ extent=[0, nx, 0, nx],
80
+ )
81
+ ax.set_aspect("equal")
82
+ ax.get_xaxis().set_visible(False)
83
+ ax.get_yaxis().set_visible(False)
68
84
 
69
- plt.savefig(
70
- os.path.join(checkpoint_dir, f"gas{i:03d}.png"),
71
- bbox_inches="tight",
72
- pad_inches=0,
73
- )
74
- plt.close()
85
+ plt.savefig(
86
+ os.path.join(checkpoint_dir, f"gas{i:03d}.png"),
87
+ bbox_inches="tight",
88
+ pad_inches=0,
89
+ )
90
+ plt.close()
@@ -0,0 +1,133 @@
1
+ Metadata-Version: 2.4
2
+ Name: jaxion
3
+ Version: 0.0.4
4
+ Summary: A differentiable simulation library for fuzzy dark matter in JAX
5
+ Author-email: Philip Mocz <philip.mocz@gmail.com>
6
+ License-Expression: Apache-2.0
7
+ Project-URL: Documentation, https://jaxion.readthedocs.io
8
+ Project-URL: Homepage, https://github.com/JaxionProject/jaxion
9
+ Requires-Python: >=3.11
10
+ Description-Content-Type: text/markdown
11
+ License-File: LICENSE
12
+ Requires-Dist: jax==0.5.3
13
+ Requires-Dist: jaxdecomp==0.2.7
14
+ Requires-Dist: tensorflow
15
+ Requires-Dist: orbax-checkpoint==0.11.18
16
+ Requires-Dist: optax==0.2.5
17
+ Requires-Dist: numpy
18
+ Requires-Dist: matplotlib
19
+ Requires-Dist: setuptools>=70.1.1
20
+ Provides-Extra: cuda12
21
+ Requires-Dist: jax[cuda12]==0.5.3; extra == "cuda12"
22
+ Dynamic: license-file
23
+
24
+ <p align="center">
25
+ <a href="https://jaxion.readthedocs.io">
26
+ <img src="docs/_static/jaxion-logo.svg" alt="jaxion logo" width="128"/>
27
+ </a>
28
+ </p>
29
+
30
+ # jaxion
31
+
32
+ [![Repo Status][status-badge]][status-link]
33
+ [![PyPI Version Status][pypi-badge]][pypi-link]
34
+ [![Test Status][workflow-test-badge]][workflow-test-link]
35
+ [![Coverage][coverage-badge]][coverage-link]
36
+ [![Readthedocs Status][docs-badge]][docs-link]
37
+ [![License][license-badge]][license-link]
38
+
39
+ [status-link]: https://www.repostatus.org/#active
40
+ [status-badge]: https://www.repostatus.org/badges/latest/active.svg
41
+ [pypi-link]: https://pypi.org/project/jaxion
42
+ [pypi-badge]: https://img.shields.io/pypi/v/jaxion?label=PyPI&logo=pypi
43
+ [workflow-test-link]: https://github.com/JaxionProject/jaxion/actions/workflows/test-package.yml
44
+ [workflow-test-badge]: https://github.com/JaxionProject/jaxion/actions/workflows/test-package.yml/badge.svg?event=push
45
+ [coverage-link]: https://app.codecov.io/gh/JaxionProject/jaxion
46
+ [coverage-badge]: https://codecov.io/github/jaxionproject/jaxion/graph/jaxion-server/badge.svg
47
+ [docs-link]: https://jaxion.readthedocs.io
48
+ [docs-badge]: https://readthedocs.org/projects/jaxion/badge
49
+ [license-link]: https://opensource.org/licenses/Apache-2.0
50
+ [license-badge]: https://img.shields.io/badge/License-Apache_2.0-blue.svg
51
+
52
+ A simple JAX-powered simulation library for numerical experiments of fuzzy dark matter, stars, gas + more!
53
+
54
+ Author: [Philip Mocz (@pmocz)](https://github.com/pmocz/)
55
+
56
+ Jaxion is built for multi-GPU scalability and is fully differentiable. It is a high-performance JAX-based simulation library for modeling fuzzy dark matter alongside stars, gas, and cosmological dynamics. Being differentiable, Jaxion can seamlessly integrate with pipelines for inverse-problems, inference, optimization, and coupling to ML models.
57
+
58
+ Jaxion is the simpler companion project to differentiable astrophysics code [Adirondax](https://github.com/AdirondaxProject/adirondax)
59
+
60
+
61
+ ## Getting started
62
+
63
+ Install with:
64
+
65
+ ```console
66
+ pip install jaxion
67
+ ```
68
+
69
+ or, for GPU support, use:
70
+
71
+ ```console
72
+ pip install jaxion[cuda12]
73
+ ```
74
+
75
+ See the docs for more info on how to [build from source](https://jaxion.readthedocs.io/en/latest/pages/installation.html).
76
+
77
+
78
+ ## Examples
79
+
80
+ Check out the `examples/` directory for demonstrations of using Jaxion.
81
+
82
+ <p align="center">
83
+ <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/cosmological_box">
84
+ <img src="examples/cosmological_box/movie.gif" alt="cosmological_box" width="128"/>
85
+ </a>
86
+ <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/dynamical_friction">
87
+ <img src="examples/dynamical_friction/movie.gif" alt="dynamical_friction" width="128"/>
88
+ </a>
89
+ <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/heating_gas">
90
+ <img src="examples/heating_gas/movie.gif" alt="heating_gas" width="128"/>
91
+ </a>
92
+ <br>
93
+ <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/heating_stars">
94
+ <img src="examples/heating_stars/movie.gif" alt="heating_stars" width="128"/>
95
+ </a>
96
+ <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/kinetic_condensation">
97
+ <img src="examples/kinetic_condensation/movie.gif" alt="kinetic_condensation" width="128"/>
98
+ </a>
99
+ <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/logo">
100
+ <img src="examples/logo/movie.gif" alt="logo" width="128"/>
101
+ </a>
102
+ <br>
103
+ <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/soliton_binary_merger">
104
+ <img src="examples/soliton_binary_merger/movie.gif" alt="soliton_binary_merger" width="128"/>
105
+ </a>
106
+ <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/soliton_merger">
107
+ <img src="examples/soliton_merger/movie.gif" alt="soliton_merger" width="128"/>
108
+ </a>
109
+ <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/tidal_stripping">
110
+ <img src="examples/tidal_stripping/movie.gif" alt="tidal_stripping" width="128"/>
111
+ </a>
112
+ </p>
113
+
114
+
115
+ ## Links
116
+
117
+ * [Code repository](https://github.com/JaxionProject/jaxion) on GitHub (this page).
118
+ * [Documentation](https://jaxion.readthedocs.io) for up-to-date information about installing and using jaxion.
119
+
120
+
121
+ ## Testing
122
+
123
+ Jaxion is tested with `pytest`. Tests are included in the `tests/` folder.
124
+
125
+
126
+ ## Contributing
127
+
128
+ Jaxion welcomes community contributions of all kinds. Open an issue or fork the code and submit a Pull request. Please check out the [Contributing Guidelines](CONTRIBUTING.md)
129
+
130
+
131
+ ## Cite this repository
132
+
133
+ TODO XXX
@@ -0,0 +1,17 @@
1
+ jaxion/__init__.py,sha256=Hdji1UQ47lG24Pqcy6UUq9L0-qy6m9Ax41L0vIYzBto,164
2
+ jaxion/analysis.py,sha256=4YT9Z2dkFoXwft3fQM1HyynVPlIdtRd80VtI2vWTyq4,1568
3
+ jaxion/constants.py,sha256=HyY2ktKQakv78jD1yQvFdM3sklUJcPgDMYlTsSPQTxI,512
4
+ jaxion/cosmology.py,sha256=UC1McXNTXGoPRYXn0nI2-csVkJWL-ZBNoCa44oU1b4w,2681
5
+ jaxion/gravity.py,sha256=3brRZelKm-soXqk_Lt3SqhbZ00woJCraqwdMuR-KooA,291
6
+ jaxion/hydro.py,sha256=KoJ02tRpAc4V3Ofzw4zbHLRaE2GdIatbOBE04_LsSRw,6980
7
+ jaxion/params_default.json,sha256=9CJrhEPsv5zGEs7_WqFyuccCDipPCDhXgKzVdqOsOWE,2775
8
+ jaxion/particles.py,sha256=pMopGvoZ0J_3EviD0WnTMmiebU9h2_8IO-p6I-E5DEU,3980
9
+ jaxion/quantum.py,sha256=GWOpN6ipfEw-6Ah2zQpxS3oqeSt_iHMDSgnVYSjXY5E,3321
10
+ jaxion/simulation.py,sha256=s6gCAt-gAoN5d46vcdxoqtn4TwsrfNGb4Cq-2p_JxsI,15927
11
+ jaxion/utils.py,sha256=rT7NM0FNEgFwN7oTgTb-jkR66Iw0xYTHHxcoikYd1ag,3572
12
+ jaxion/visualization.py,sha256=K5EQOHPfj7LF29fW_naWH8a7TEyEa3wIaQw7rpebx0w,2914
13
+ jaxion-0.0.4.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
14
+ jaxion-0.0.4.dist-info/METADATA,sha256=66lU0x1ZofP-uCwD4hF0C45Ao-pNhPvMzr6Krs__Hws,5305
15
+ jaxion-0.0.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
16
+ jaxion-0.0.4.dist-info/top_level.txt,sha256=S1OV2VdlDG_9UwpKOIji4itQGOS-VWUOWUi3GeXWzt0,7
17
+ jaxion-0.0.4.dist-info/RECORD,,
@@ -1,76 +0,0 @@
1
- Metadata-Version: 2.4
2
- Name: jaxion
3
- Version: 0.0.2
4
- Summary: A differentiable simulation library for fuzzy dark matter in JAX
5
- Author-email: Philip Mocz <philip.mocz@gmail.com>
6
- License-Expression: Apache-2.0
7
- Project-URL: Documentation, https://jaxion.readthedocs.io
8
- Project-URL: Homepage, https://github.com/JaxionProject/jaxion
9
- Requires-Python: >=3.11
10
- Description-Content-Type: text/markdown
11
- License-File: LICENSE
12
- Requires-Dist: jax==0.5.3
13
- Requires-Dist: jaxdecomp==0.2.7
14
- Requires-Dist: tensorflow
15
- Requires-Dist: orbax-checkpoint==0.11.18
16
- Requires-Dist: optax==0.2.5
17
- Requires-Dist: numpy
18
- Requires-Dist: matplotlib
19
- Requires-Dist: setuptools>=70.1.1
20
- Provides-Extra: cuda12
21
- Requires-Dist: jax[cuda12]==0.5.3; extra == "cuda12"
22
- Dynamic: license-file
23
-
24
- # jaxion
25
-
26
- [![Repo Status][status-badge]][status-link]
27
- [![PyPI Version Status][pypi-badge]][pypi-link]
28
- [![Test Status][workflow-test-badge]][workflow-test-link]
29
- [![Readthedocs Status][docs-badge]][docs-link]
30
- [![License][license-badge]][license-link]
31
-
32
- [status-link]: https://www.repostatus.org/#active
33
- [status-badge]: https://www.repostatus.org/badges/latest/active.svg
34
- [pypi-link]: https://pypi.org/project/jaxion
35
- [pypi-badge]: https://img.shields.io/pypi/v/jaxion?label=PyPI&logo=pypi
36
- [workflow-test-link]: https://github.com/JaxionProject/jaxion/actions/workflows/test-package.yml
37
- [workflow-test-badge]: https://github.com/JaxionProject/jaxion/actions/workflows/test-package.yml/badge.svg?event=push
38
- [docs-link]: https://jaxion.readthedocs.io
39
- [docs-badge]: https://readthedocs.org/projects/jaxion/badge
40
- [license-link]: https://opensource.org/licenses/Apache-2.0
41
- [license-badge]: https://img.shields.io/badge/License-Apache_2.0-blue.svg
42
-
43
- A simple JAX-powered simulation library for numerical experiments of fuzzy dark matter, stars, gas + more!
44
-
45
- Author: [Philip Mocz (@pmocz)](https://github.com/pmocz/)
46
-
47
- ⚠️ Jaxion is currently being developed and is not yet ready for use. Check back later ⚠️
48
-
49
- Jaxion is built for multi-GPU scalability and is fully differentiable. It is a high-performance JAX-based simulation library for modeling fuzzy dark matter alongside stars, gas, and cosmological dynamics. Being differentiable, Jaxion can seemlessly integrate with piplines for inverse-problems, inference, optimization, and coupling to ML models.
50
-
51
-
52
- ## Getting started
53
-
54
- Install with:
55
-
56
- ```console
57
- pip install jaxion
58
- ```
59
-
60
- or, for GPU support use:
61
-
62
- ```console
63
- pip install jaxion[cuda12]
64
- ```
65
-
66
- Check out the `examples/` directory for demonstrations of using Jaxion.
67
-
68
-
69
- ## Links
70
-
71
- * [Code repository](https://github.com/JaxionProject/jaxion) on GitHub (this page).
72
- * [Documentation](https://jaxion.readthedocs.io) for up-to-date information about installing and running jaxion.
73
-
74
-
75
- ## Cite this repository
76
-
@@ -1,15 +0,0 @@
1
- jaxion/__init__.py,sha256=gHrU_SvIBIbuYt-MeEeyiMyuJKwLjTTiz513OZ7CDHQ,95
2
- jaxion/constants.py,sha256=b0vgqtezIEGp4O5JRJo3YK3eQC8AN6hmV84-4pYWgmc,528
3
- jaxion/gravity.py,sha256=DdVOhAaoIPLXtW7ALrMK2umLPkStzxJ778l1SD2yMkY,266
4
- jaxion/hydro.py,sha256=3ZFNDhIaPBw4uTBqFbjJotCLkuKG8AvRcY1T78QQ96U,6953
5
- jaxion/params_default.json,sha256=4-XNPN52phGxYsP6FOeIJH0ALtICjrQxApotz-NhJEc,1782
6
- jaxion/particles.py,sha256=7oJBVyddfOXUDYxQuGIs8_ByL2Hu6jkuWE1MWtshFjk,3953
7
- jaxion/quantum.py,sha256=91tEotQhX1CB3HCaRCDARRPX_dEx58GIUCluY5RyPvA,377
8
- jaxion/simulation.py,sha256=KcIi91LHJO5K4M0gxU4qc7e4JEeOyW1HUZOYwt8FOG0,10385
9
- jaxion/utils.py,sha256=LCVx5_N9-6HsbENENOEz8tI6Czl0Z8Bit2icf1k4024,2880
10
- jaxion/visualization.py,sha256=WKID2_LFbbqXwmW_TLTnVznKYHfjiDKNbjD9xS4GfLc,2209
11
- jaxion-0.0.2.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
12
- jaxion-0.0.2.dist-info/METADATA,sha256=uUe69xVa678F2uV28-CczduNdfoOtBpwhfuM_mJKXpM,2811
13
- jaxion-0.0.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
14
- jaxion-0.0.2.dist-info/top_level.txt,sha256=S1OV2VdlDG_9UwpKOIji4itQGOS-VWUOWUi3GeXWzt0,7
15
- jaxion-0.0.2.dist-info/RECORD,,
File without changes