jaxion 0.0.2__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxion/__init__.py CHANGED
@@ -1,2 +1,3 @@
1
1
  from .simulation import Simulation as Simulation
2
2
  from .constants import constants as constants
3
+ from .analysis import radial_power_spectrum as radial_power_spectrum
jaxion/analysis.py ADDED
@@ -0,0 +1,43 @@
1
+ import jax.numpy as jnp
2
+ import jaxdecomp as jd
3
+
4
+
5
def radial_power_spectrum(data_cube, kx, ky, kz, box_size):
    """
    Computes the radially averaged power spectral density of data_cube (3D).

    data_cube: jnp.ndarray (3D, must be cubic)
    kx, ky, kz: jnp.ndarray wavenumber grids in physical units (2*pi/length),
        laid out like the output of jd.fft.pfft3d(data_cube)
    box_size: float (physical size of box)
    Returns: Pf (radial power spectrum), k (wavenumbers), total_power

    Modes are binned by |k| / dk, where dk = 2*pi / box_size is the
    fundamental wavenumber, so bin i collects the modes whose |k| / dk lies
    within 0.5 of i and is reported at k[i] = i * dk.
    """
    dim = data_cube.ndim
    nx = data_cube.shape[0]
    dx = box_size / nx
    dk = 2.0 * jnp.pi / box_size
    half_size = nx // 2 + 1

    # Power carried by each Fourier mode
    data_cube_hat = jd.fft.pfft3d(data_cube)
    phi_k = 0.5 * jnp.abs(data_cube_hat) ** 2 / nx**dim * dx**dim
    total_power = jnp.sum(phi_k)

    # |k| in units of the fundamental mode: the integer bin centres used below
    # must line up with k = i * dk, which only holds after dividing by dk
    # (binning raw physical |k| is correct only when box_size == 2*pi).
    k_r = jnp.sqrt(kx**2 + ky**2 + kz**2) / dk
    bin_range = (-0.5, half_size - 0.5)
    Pf, _ = jnp.histogram(k_r, range=bin_range, bins=half_size, weights=phi_k)
    norm, _ = jnp.histogram(k_r, range=bin_range, bins=half_size)
    # mean power per shell; empty shells keep Pf == 0 (divide by 1, not 0)
    Pf = Pf / (norm + (norm == 0))

    k = dk * jnp.arange(half_size)

    # convert to a spectral density and apply the 3D shell factor 4 pi k^2
    Pf = Pf / dk**dim
    Pf = Pf * 4.0 * jnp.pi * k**2

    return Pf, k, total_power
jaxion/cosmology.py ADDED
@@ -0,0 +1,74 @@
1
+ import jax.numpy as jnp
2
+
3
+ # Pure functions for cosmology simulation
4
+
5
+
6
def get_physical_time_interval(z_start, z_end, omega_matter, omega_lambda, little_h):
    """Return the physical (cosmic) time elapsed between z_start and z_end.

    Integrates dt = da / (da/dt) over the scale factor, with
    da/dt = H0 * sqrt(omega_matter / a + omega_lambda * a^2) and
    H0 = 0.1 * little_h in km/s/kpc (i.e. 100 h km/s/Mpc). The integral is
    evaluated with the trapezoid rule on 10000 evenly spaced samples of a.
    """
    hubble0 = 0.1 * little_h  # [km/s/kpc]
    a_first = 1.0 / (1.0 + z_start)
    a_last = 1.0 / (1.0 + z_end)
    a_samples = jnp.linspace(a_first, a_last, 10000)
    expansion_rate = hubble0 * jnp.sqrt(
        (omega_matter / a_samples) + (omega_lambda * a_samples**2)
    )
    elapsed = jnp.trapezoid(1.0 / expansion_rate, a_samples)
    return elapsed
17
+
18
+
19
def get_supercomoving_time_interval(
    z_start, z_end, omega_matter, omega_lambda, little_h
):
    """Return the supercomoving time interval between z_start and z_end.

    Supercomoving time is defined by dt_hat = a^-2 dt, hence
    dt_hat/da = a^-2 / (da/dt) with
    da/dt = H0 * sqrt(omega_matter / a + omega_lambda * a^2) and
    H0 = 0.1 * little_h in km/s/kpc. The integral over the scale factor is
    evaluated with the trapezoid rule on 10000 evenly spaced samples.
    """
    hubble0 = 0.1 * little_h  # 100 h km/s/Mpc expressed in km/s/kpc
    a_samples = jnp.linspace(1.0 / (1.0 + z_start), 1.0 / (1.0 + z_end), 10000)
    expansion_rate = hubble0 * jnp.sqrt(
        (omega_matter / a_samples) + (omega_lambda * a_samples**2)
    )
    integrand = a_samples**-2 / expansion_rate
    return jnp.trapezoid(integrand, a_samples)
35
+
36
+
37
def get_scale_factor(z_start, dt_hat, omega_matter, omega_lambda, little_h):
    """Find the scale factor reached after a supercomoving time dt_hat.

    Bisects on the scale factor a inside the bracket [0, 2], starting from
    the scale factor at z_start, until the supercomoving interval from
    z_start to redshift 1/a - 1 matches dt_hat to an absolute tolerance of
    1e-6, or 100 iterations have been made. Returns the last candidate a.
    Uses Python control flow on array values, so it is meant to be called
    eagerly rather than under jax.jit.
    """
    tolerance = 1e-6
    lo, hi = 0.0, 2.0  # bracket on the scale factor
    a = 1.0 / (1.0 + z_start)
    for _ in range(100):
        residual = (
            get_supercomoving_time_interval(
                z_start, 1.0 / a - 1.0, omega_matter, omega_lambda, little_h
            )
            - dt_hat
        )
        if jnp.abs(residual) < tolerance:
            break
        if residual > 0:
            hi = a  # overshot: the target lies at a smaller scale factor
        else:
            lo = a  # undershot: the target lies at a larger scale factor
        a = 0.5 * (lo + hi)
    return a
60
+
61
+
62
def get_next_scale_factor(z_start, dt_hat, omega_matter, omega_lambda, little_h):
    """Advance the scale factor from z_start by a supercomoving step dt_hat.

    Uses the explicit midpoint (RK2) rule on da/dt_hat = a^2 * da/dt, with
    da/dt = H0 * sqrt(omega_matter / a + omega_lambda * a^2) and
    H0 = 0.1 * little_h.
    """
    hubble0 = 0.1 * little_h

    def da_dt_hat(scale):
        # a^2 times the Friedmann expansion rate da/dt
        return scale**2 * (
            hubble0 * jnp.sqrt((omega_matter / scale) + (omega_lambda * scale**2))
        )

    a0 = 1.0 / (1.0 + z_start)
    slope_start = da_dt_hat(a0)
    slope_mid = da_dt_hat(a0 + 0.5 * dt_hat * slope_start)
    return a0 + dt_hat * slope_mid
jaxion/gravity.py CHANGED
@@ -1,9 +1,10 @@
1
1
  import jax.numpy as jnp
2
+ import jaxdecomp as jd
2
3
 
3
4
  # Pure functions for gravity calculations
4
5
 
5
6
 
6
7
  def calculate_gravitational_potential(rho, k_sq, G, rho_bar):
7
- Vhat = -jnp.fft.fftn(4.0 * jnp.pi * G * (rho - rho_bar)) / (k_sq + (k_sq == 0))
8
- V = jnp.real(jnp.fft.ifftn(Vhat))
8
+ Vhat = -jd.fft.pfft3d(4.0 * jnp.pi * G * (rho - rho_bar)) / (k_sq + (k_sq == 0))
9
+ V = jnp.real(jd.fft.pifft3d(Vhat))
9
10
  return V
jaxion/hydro.py CHANGED
@@ -1,4 +1,5 @@
1
1
  import jax.numpy as jnp
2
+ import jaxdecomp as jd
2
3
 
3
4
  # Pure functions for hydro simulation
4
5
 
@@ -89,11 +90,11 @@ def get_flux(rho_L, vx_L, vy_L, vz_L, rho_R, vx_R, vy_R, vz_R, cs):
89
90
 
90
91
 
91
92
  def hydro_accelerate(vx, vy, vz, V, kx, ky, kz, dt):
92
- V_hat = jnp.fft.fftn(V)
93
+ V_hat = jd.fft.pfft3d(V)
93
94
 
94
- ax = -jnp.real(jnp.fft.ifftn(1.0j * kx * V_hat))
95
- ay = -jnp.real(jnp.fft.ifftn(1.0j * ky * V_hat))
96
- az = -jnp.real(jnp.fft.ifftn(1.0j * kz * V_hat))
95
+ ax = -jnp.real(jd.fft.pifft3d(1.0j * kx * V_hat))
96
+ ay = -jnp.real(jd.fft.pifft3d(1.0j * ky * V_hat))
97
+ az = -jnp.real(jd.fft.pifft3d(1.0j * kz * V_hat))
97
98
 
98
99
  vx += ax * dt
99
100
  vy += ay * dt
@@ -28,7 +28,7 @@
28
28
  "domain": {
29
29
  "box_size": {
30
30
  "default": 10.0,
31
- "description": "periodic domain box size (kpc)."
31
+ "description": "periodic domain box size [kpc]."
32
32
  },
33
33
  "resolution_base": {
34
34
  "default": 32,
@@ -40,38 +40,75 @@
40
40
  }
41
41
  },
42
42
  "time": {
43
- "start": 0.0,
44
- "end": 1.0,
45
- "adaptive": false
43
+ "start": {
44
+ "default": 0.0,
45
+ "description": "simulation start time [kpc/(km/s)] or [redshift] (cosmology=true)."
46
+ },
47
+ "end": {
48
+ "default": 1.0,
49
+ "description": "simulation end time [kpc/(km/s)] or [redshift] (cosmology=true)."
50
+ },
51
+ "adaptive": {
52
+ "default": false,
53
+ "description": "switch on for adaptive time stepping."
54
+ }
46
55
  },
47
56
  "output": {
48
- "path": "./checkpoints",
49
- "num_checkpoints": 100,
50
- "save": true,
51
- "plot_dynamic_range": 100.0
57
+ "path": {
58
+ "default": "./checkpoints",
59
+ "description": "path to output directory."
60
+ },
61
+ "num_checkpoints": {
62
+ "default": 100,
63
+ "description": "number of checkpoints to save."
64
+ },
65
+ "save": {
66
+ "default": true,
67
+ "description": "switch on to save checkpoints."
68
+ },
69
+ "plot_dynamic_range": {
70
+ "default": 100.0,
71
+ "description": "dynamic range for plotting."
72
+ }
52
73
  },
53
74
  "quantum": {
54
75
  "m_22": {
55
76
  "default": 1.0,
56
- "description": "axion mass in units of 10^{-22} eV."
77
+ "description": "axion mass [10^{-22} eV]."
57
78
  }
58
79
  },
59
80
  "hydro": {
60
81
  "sound_speed": {
61
82
  "default": 1.0,
62
- "description": "Isothermal sound speed (km/s)."
83
+ "description": "isothermal sound speed [km/s]."
63
84
  }
64
85
  },
65
86
  "particles": {
66
- "num_particles": 0,
67
- "particle_mass": 1.0
87
+ "num_particles": {
88
+ "default": 0,
89
+ "description": "number of particles."
90
+ },
91
+ "particle_mass": {
92
+ "default": 1.0,
93
+ "description": "particle mass [M_sun]."
94
+ }
68
95
  },
69
96
  "cosmology": {
70
- "Omega_m": 0.3,
71
- "Omega_Lambda": 0.7
97
+ "omega_matter": {
98
+ "default": 0.3,
99
+ "description": "matter density parameter."
100
+ },
101
+ "omega_lambda": {
102
+ "default": 0.7,
103
+ "description": "dark energy density parameter."
104
+ },
105
+ "little_h": {
106
+ "default": 0.7,
107
+ "description": "Hubble parameter little h (H_0 = 100 h km/s/Mpc)."
108
+ }
72
109
  },
73
110
  "version": {
74
111
  "default": "unknown",
75
- "description": "version of the jaxion library used to create these parameters (automatically detected)."
112
+ "description": "jaxion version used (auto detected)."
76
113
  }
77
114
  }
jaxion/particles.py CHANGED
@@ -1,5 +1,6 @@
1
1
  import jax
2
2
  import jax.numpy as jnp
3
+ import jaxdecomp as jd
3
4
 
4
5
  # Pure functions for particle-mesh calculations (stars, BHs, ...)
5
6
 
@@ -64,10 +65,10 @@ def get_acceleration(pos, V, kx, ky, kz, dx):
64
65
  i, ip1, w_i, w_ip1 = get_cic_indices_and_weights(pos, dx, resolution)
65
66
 
66
67
  # find accelerations on the grid
67
- V_hat = jnp.fft.fftn(V)
68
- ax = -jnp.real(jnp.fft.ifftn(1.0j * kx * V_hat))
69
- ay = -jnp.real(jnp.fft.ifftn(1.0j * ky * V_hat))
70
- az = -jnp.real(jnp.fft.ifftn(1.0j * kz * V_hat))
68
+ V_hat = jd.fft.pfft3d(V)
69
+ ax = -jnp.real(jd.fft.pifft3d(1.0j * kx * V_hat))
70
+ ay = -jnp.real(jd.fft.pifft3d(1.0j * ky * V_hat))
71
+ az = -jnp.real(jd.fft.pifft3d(1.0j * kz * V_hat))
71
72
  a_grid = jnp.stack((ax, ay, az), axis=-1)
72
73
 
73
74
  # interpolate the accelerations to the particle positions
jaxion/quantum.py CHANGED
@@ -1,4 +1,5 @@
1
1
  import jax.numpy as jnp
2
+ import jaxdecomp as jd
2
3
 
3
4
  # Pure functions for quantum simulation
4
5
 
@@ -9,7 +10,86 @@ def quantum_kick(psi, V, m_per_hbar, dt):
9
10
 
10
11
 
11
12
  def quantum_drift(psi, k_sq, m_per_hbar, dt):
12
- psi_hat = jnp.fft.fftn(psi)
13
+ psi_hat = jd.fft.pfft3d(psi)
13
14
  psi_hat = jnp.exp(dt * (-1.0j * k_sq / m_per_hbar / 2.0)) * psi_hat
14
- psi = jnp.fft.ifftn(psi_hat)
15
+ psi = jd.fft.pifft3d(psi_hat)
15
16
  return psi
17
+
18
+
19
def get_gradient(psi, kx, ky, kz):
    """
    Spectral gradient of the wavefunction psi.

    psi: jnp.ndarray (3D)
    kx, ky, kz: wavenumber grids in the layout produced by jd.fft.pfft3d
    Returns: (grad_psi_x, grad_psi_y, grad_psi_z), each the inverse transform
        of i * k_axis * FFT(psi)
    """
    psi_hat = jd.fft.pfft3d(psi)
    return tuple(jd.fft.pifft3d(1.0j * k_axis * psi_hat) for k_axis in (kx, ky, kz))
35
+
36
+
37
def _wrap_phase(dphi):
    """Shift phase differences by at most one 2*pi into the range (-pi, pi]."""
    dphi = jnp.where(dphi > jnp.pi, dphi - 2 * jnp.pi, dphi)
    return jnp.where(dphi <= -jnp.pi, dphi + 2 * jnp.pi, dphi)


def quantum_velocity(psi, box_size, m_per_hbar):
    """
    Compute the velocity from the wave-function
    v = nabla S / m
    psi = sqrt(rho) exp(i S / hbar)

    psi: jnp.ndarray (3D, cubic), periodic in every direction
    box_size: float (physical size of box)
    m_per_hbar: float (particle mass divided by hbar)
    Returns: vx, vy, vz

    Along each axis the phase gradient is the average of the periodic forward
    and backward differences of S/hbar. Each difference is wrapped into
    (-pi, pi] first, so the 2*pi jumps of jnp.angle do not show up as
    spurious velocities.
    """
    dx = box_size / psi.shape[0]
    s_per_hbar = jnp.angle(psi)

    def _velocity_along(axis):
        # one-sided phase differences (wrapped), converted to velocities
        forward = _wrap_phase(jnp.roll(s_per_hbar, -1, axis=axis) - s_per_hbar)
        backward = _wrap_phase(s_per_hbar - jnp.roll(s_per_hbar, 1, axis=axis))
        forward = forward / dx / m_per_hbar
        backward = backward / dx / m_per_hbar
        return 0.5 * (forward + backward)

    vx, vy, vz = (_velocity_along(axis) for axis in range(3))
    return vx, vy, vz
jaxion/simulation.py CHANGED
@@ -6,11 +6,19 @@ import json
6
6
  import time
7
7
 
8
8
  from .constants import constants
9
- from .quantum import quantum_kick, quantum_drift
9
+ from .quantum import quantum_kick, quantum_drift, quantum_velocity
10
10
  from .gravity import calculate_gravitational_potential
11
11
  from .hydro import hydro_fluxes, hydro_accelerate
12
12
  from .particles import particles_accelerate, particles_drift, bin_particles
13
- from .utils import set_up_parameters, print_parameters
13
+ from .cosmology import get_supercomoving_time_interval, get_next_scale_factor
14
+ from .utils import (
15
+ set_up_parameters,
16
+ print_parameters,
17
+ xmeshgrid,
18
+ xmeshgrid_transpose,
19
+ xzeros,
20
+ xones,
21
+ )
14
22
  from .visualization import plot_sim
15
23
 
16
24
 
@@ -21,10 +29,20 @@ class Simulation:
21
29
  Parameters
22
30
  ----------
23
31
  params (dict): The Python dictionary that contains the simulation parameters.
32
+ Params can also be a string path to a checkpoint directory to load a saved simulation.
24
33
 
25
34
  """
26
35
 
27
- def __init__(self, params):
36
+ def __init__(self, params, sharding=None):
37
+ # allow loading directly from a checkpoint path
38
+ load_from_checkpoint = False
39
+ checkpoint_dir = ""
40
+ if isinstance(params, str):
41
+ load_from_checkpoint = True
42
+ checkpoint_dir = os.path.join(os.getcwd(), params)
43
+ with open(os.path.join(checkpoint_dir, "params.json"), "r") as f:
44
+ params = json.load(f)
45
+
28
46
  # start from default simulation parameters and update with user params
29
47
  self._params = set_up_parameters(params)
30
48
 
@@ -32,38 +50,80 @@ class Simulation:
32
50
  if self.resolution % 2 != 0:
33
51
  raise ValueError("Resolution must be divisible by 2.")
34
52
 
53
+ if self.params["time"]["adaptive"]:
54
+ raise NotImplementedError("Adaptive time stepping is not yet implemented.")
55
+
56
+ if self.params["physics"]["cosmology"]:
57
+ if (
58
+ self.params["physics"]["hydro"]
59
+ or self.params["physics"]["particles"]
60
+ or self.params["physics"]["external_potential"]
61
+ ):
62
+ raise NotImplementedError(
63
+ "Cosmological hydro/particles/external_potential physics is not yet implemented."
64
+ )
65
+
66
+ if self.params["physics"]["hydro"] or self.params["physics"]["particles"]:
67
+ if sharding is not None:
68
+ raise NotImplementedError(
69
+ "hydro/particles sharding is not yet implemented."
70
+ )
71
+
35
72
  if self.params["output"]["save"]:
36
- print("Simulation initialized with parameters:")
37
- print_parameters(self.params)
73
+ if jax.process_index() == 0:
74
+ print("Simulation parameters:")
75
+ print_parameters(self.params)
76
+
77
+ # jitted functions
78
+ self.xmeshgrid_jit = jax.jit(
79
+ xmeshgrid, in_shardings=None, out_shardings=sharding
80
+ )
81
+ self.xmeshgrid_transpose_jit = jax.jit(
82
+ xmeshgrid_transpose, in_shardings=None, out_shardings=sharding
83
+ )
84
+ self.xzeros_jit = jax.jit(
85
+ xzeros, static_argnums=0, in_shardings=None, out_shardings=sharding
86
+ )
87
+ self.xones_jit = jax.jit(
88
+ xones, static_argnums=0, in_shardings=None, out_shardings=sharding
89
+ )
38
90
 
39
91
  # simulation state
40
92
  self.state = {}
41
93
  self.state["t"] = 0.0
94
+ if self.params["physics"]["cosmology"]:
95
+ self.state["redshift"] = 0.0
42
96
  if self.params["physics"]["quantum"]:
43
- self.state["psi"] = (
44
- jnp.zeros((self.resolution, self.resolution, self.resolution)) * 1j
45
- )
97
+ self.state["psi"] = self.xzeros_jit(self.resolution) * 1j
46
98
  if self.params["physics"]["external_potential"]:
47
- self.state["V_ext"] = jnp.zeros(
48
- (self.resolution, self.resolution, self.resolution)
49
- )
99
+ self.state["V_ext"] = self.xzeros_jit(self.resolution)
50
100
  if self.params["physics"]["hydro"]:
51
101
  self.state["rho"] = jnp.zeros(
52
- (self.resolution, self.resolution, self.resolution)
102
+ (self.resolution, self.resolution, self.resolution),
53
103
  )
54
104
  self.state["vx"] = jnp.zeros(
55
- (self.resolution, self.resolution, self.resolution)
105
+ (self.resolution, self.resolution, self.resolution),
56
106
  )
57
107
  self.state["vy"] = jnp.zeros(
58
- (self.resolution, self.resolution, self.resolution)
108
+ (self.resolution, self.resolution, self.resolution),
59
109
  )
60
110
  self.state["vz"] = jnp.zeros(
61
- (self.resolution, self.resolution, self.resolution)
111
+ (self.resolution, self.resolution, self.resolution),
62
112
  )
63
113
  if self.params["physics"]["particles"]:
64
114
  self.state["pos"] = jnp.zeros((self.num_particles, 3))
65
115
  self.state["vel"] = jnp.zeros((self.num_particles, 3))
66
116
 
117
+ if load_from_checkpoint:
118
+ options = ocp.CheckpointManagerOptions()
119
+ async_checkpoint_manager = ocp.CheckpointManager(
120
+ checkpoint_dir, options=options
121
+ )
122
+ step = async_checkpoint_manager.latest_step()
123
+ self.state = async_checkpoint_manager.restore(
124
+ step, args=ocp.args.StandardRestore(self.state)
125
+ )
126
+
67
127
  @property
68
128
  def resolution(self):
69
129
  return (
@@ -104,19 +164,24 @@ class Simulation:
104
164
  def grid(self):
105
165
  hx = 0.5 * self.dx
106
166
  x_lin = jnp.linspace(hx, self.box_size - hx, self.resolution)
107
- X, Y, Z = jnp.meshgrid(x_lin, x_lin, x_lin, indexing="ij")
108
- return X, Y, Z
167
+ xx, yy, zz = self.xmeshgrid_jit(x_lin)
168
+ return xx, yy, zz
109
169
 
110
170
  @property
111
171
  def kgrid(self):
112
172
  nx = self.resolution
113
173
  k_lin = (2.0 * jnp.pi / self.box_size) * jnp.arange(-nx / 2, nx / 2)
114
- kx, ky, kz = jnp.meshgrid(k_lin, k_lin, k_lin, indexing="ij")
174
+ kx, ky, kz = self.xmeshgrid_transpose_jit(k_lin)
115
175
  kx = jnp.fft.ifftshift(kx)
116
176
  ky = jnp.fft.ifftshift(ky)
117
177
  kz = jnp.fft.ifftshift(kz)
118
178
  return kx, ky, kz
119
179
 
180
+ @property
181
+ def quantum_velocity(self):
182
+ m_per_hbar = self.axion_mass / constants["reduced_planck_constant"]
183
+ return quantum_velocity(self.state["psi"], self.box_size, m_per_hbar)
184
+
120
185
  def _calc_rho_bar(self, state):
121
186
  rho_bar = 0.0
122
187
  if self.params["physics"]["quantum"]:
@@ -133,7 +198,7 @@ class Simulation:
133
198
  def _calc_grav_potential(self, state, k_sq):
134
199
  G = constants["gravitational_constant"]
135
200
  m_particle = self.params["particles"]["particle_mass"]
136
- rho_bar = self._calc_rho_bar(self.state)
201
+ rho_bar = self._calc_rho_bar(state)
137
202
  rho_tot = 0.0
138
203
  if self.params["physics"]["quantum"]:
139
204
  rho_tot += jnp.abs(state["psi"]) ** 2
@@ -141,6 +206,10 @@ class Simulation:
141
206
  rho_tot += state["rho"]
142
207
  if self.params["physics"]["particles"]:
143
208
  rho_tot += bin_particles(state["pos"], self.dx, self.resolution, m_particle)
209
+ if self.params["physics"]["cosmology"]:
210
+ scale_factor = 1.0 / (1.0 + state["redshift"])
211
+ rho_bar *= scale_factor
212
+ rho_tot *= scale_factor
144
213
  return calculate_gravitational_potential(rho_tot, k_sq, G, rho_bar)
145
214
 
146
215
  @property
@@ -167,19 +236,40 @@ class Simulation:
167
236
  # Simulation parameters
168
237
  dx = self.dx
169
238
  m_per_hbar = self.axion_mass / constants["reduced_planck_constant"]
239
+ box_size = self.box_size
170
240
 
171
241
  dt_fac = 1.0
172
242
  dt_kin = dt_fac * (m_per_hbar / 6.0) * (dx * dx)
243
+ t_start = self.params["time"]["start"]
173
244
  t_end = self.params["time"]["end"]
245
+ t_span = t_end - t_start
246
+ state["t"] = t_start
247
+
248
+ # cosmology
249
+ if self.params["physics"]["cosmology"]:
250
+ z_start = self.params["time"]["start"]
251
+ z_end = self.params["time"]["end"]
252
+ omega_matter = self.params["cosmology"]["omega_matter"]
253
+ omega_lambda = self.params["cosmology"]["omega_lambda"]
254
+ little_h = self.params["cosmology"]["little_h"]
255
+ t_span = get_supercomoving_time_interval(
256
+ z_start, z_end, omega_matter, omega_lambda, little_h
257
+ )
258
+ state["t"] = 0.0
259
+ state["redshift"] = z_start
174
260
 
261
+ # hydro
175
262
  c_sound = self.params["hydro"]["sound_speed"]
176
- box_size = self.box_size
177
263
 
178
264
  # round up to the nearest multiple of num_checkpoints
179
265
  num_checkpoints = self.params["output"]["num_checkpoints"]
180
- nt = int(round(round(t_end / dt_kin) / num_checkpoints) * num_checkpoints)
266
+ nt = int(round(round(t_span / dt_kin) / num_checkpoints) * num_checkpoints)
181
267
  nt_sub = int(round(nt / num_checkpoints))
182
- dt = t_end / nt
268
+ dt = t_span / nt
269
+
270
+ # distributed arrays (fixed) needed for calculations
271
+ kx, ky, kz = None, None, None
272
+ k_sq = None
183
273
 
184
274
  # Fourier space variables
185
275
  if self.params["physics"]["gravity"] or self.params["physics"]["quantum"]:
@@ -192,10 +282,14 @@ class Simulation:
192
282
  checkpoint_dir = checkpoint_dir = os.path.join(
193
283
  os.getcwd(), self.params["output"]["path"]
194
284
  )
195
- path = ocp.test_utils.erase_and_create_empty(checkpoint_dir)
285
+ path = os.path.join(os.getcwd(), checkpoint_dir)
286
+ if jax.process_index() == 0:
287
+ path = ocp.test_utils.erase_and_create_empty(checkpoint_dir)
196
288
  async_checkpoint_manager = ocp.CheckpointManager(path, options=options)
197
289
 
198
- def _kick(state, dt):
290
+ carry = (state, kx, ky, kz, k_sq)
291
+
292
+ def _kick(state, kx, ky, kz, k_sq, dt):
199
293
  # Kick (half-step)
200
294
  if (
201
295
  self.params["physics"]["gravity"]
@@ -222,7 +316,7 @@ class Simulation:
222
316
  state["vel"], state["pos"], V, kx, ky, kz, dx, dt
223
317
  )
224
318
 
225
- def _drift(state, dt):
319
+ def _drift(state, k_sq, dt):
226
320
  # Drift (full-step)
227
321
  if self.params["physics"]["quantum"]:
228
322
  state["psi"] = quantum_drift(state["psi"], k_sq, m_per_hbar, dt)
@@ -234,20 +328,30 @@ class Simulation:
234
328
  state["pos"] = particles_drift(state["pos"], state["vel"], dt, box_size)
235
329
 
236
330
  @jax.jit
237
- def _update(_, state):
331
+ def _update(_, carry):
238
332
  # Update the simulation state by one timestep
239
333
  # according to a 2nd-order `kick-drift-kick` scheme
240
- _kick(state, 0.5 * dt)
241
- _drift(state, dt)
242
- _kick(state, 0.5 * dt)
243
-
244
- # update time
334
+ state, kx, ky, kz, k_sq = carry
335
+ _kick(state, kx, ky, kz, k_sq, 0.5 * dt)
336
+ _drift(state, k_sq, dt)
337
+ # update time & redshift
245
338
  state["t"] += dt
339
+ if self.params["physics"]["cosmology"]:
340
+ scale_factor = get_next_scale_factor(
341
+ state["redshift"],
342
+ dt,
343
+ self.params["cosmology"]["omega_matter"],
344
+ self.params["cosmology"]["omega_lambda"],
345
+ self.params["cosmology"]["little_h"],
346
+ )
347
+ state["redshift"] = 1.0 / scale_factor - 1.0
348
+ _kick(state, kx, ky, kz, k_sq, 0.5 * dt)
246
349
 
247
- return state
350
+ return state, kx, ky, kz, k_sq
248
351
 
249
352
  # save initial state
250
- print("Starting simulation ...")
353
+ if jax.process_index() == 0:
354
+ print(f"Starting simulation (res={self.resolution}, nt={nt}) ...")
251
355
  if self.params["output"]["save"]:
252
356
  with open(os.path.join(checkpoint_dir, "params.json"), "w") as f:
253
357
  json.dump(self.params, f, indent=2)
@@ -259,7 +363,8 @@ class Simulation:
259
363
  t_start_timer = time.time()
260
364
  if self.params["output"]["save"]:
261
365
  for i in range(1, num_checkpoints + 1):
262
- state = jax.lax.fori_loop(0, nt_sub, _update, init_val=state)
366
+ carry = jax.lax.fori_loop(0, nt_sub, _update, init_val=carry)
367
+ state, kx, ky, kz, k_sq = carry
263
368
  jax.block_until_ready(state)
264
369
  # save state
265
370
  async_checkpoint_manager.save(i, args=ocp.args.StandardSave(state))
@@ -267,15 +372,17 @@ class Simulation:
267
372
  elapsed = time.time() - t_start_timer
268
373
  est_total = elapsed / i * num_checkpoints
269
374
  est_remaining = est_total - elapsed
270
- print(
271
- f"{percent:.1f}%: estimated time remaining (s): {est_remaining:.1f}"
272
- )
375
+ if jax.process_index() == 0:
376
+ print(
377
+ f"{percent:.1f}%: estimated time remaining (s): {est_remaining:.1f}"
378
+ )
273
379
  plot_sim(state, checkpoint_dir, i, self.params)
274
380
  async_checkpoint_manager.wait_until_finished()
275
381
  else:
276
382
  state = jax.lax.fori_loop(0, nt, _update, init_val=state)
277
383
  jax.block_until_ready(state)
278
- print("Simulation Run Time (s): ", time.time() - t_start_timer)
384
+ if jax.process_index() == 0:
385
+ print("Simulation Run Time (s): ", time.time() - t_start_timer)
279
386
 
280
387
  return state
281
388
 
jaxion/utils.py CHANGED
@@ -5,6 +5,7 @@ from pathlib import Path
5
5
  import importlib.resources
6
6
  import json
7
7
  from importlib.metadata import version
8
+ import jax.numpy as jnp
8
9
 
9
10
 
10
11
  def print_parameters(params):
@@ -90,3 +91,29 @@ def run_example_main(example_path, argv=None):
90
91
  os.chdir(old_cwd)
91
92
  sys.argv = old_argv
92
93
  return result
94
+
95
+
96
+ # Make a distributed meshgrid function
97
def xmeshgrid(x_lin):
    """Return the 3D 'ij'-indexed meshgrid built from x_lin on every axis.

    Returns a tuple (xx, yy, zz) where xx varies along axis 0, yy along
    axis 1 and zz along axis 2.
    """
    grids = jnp.meshgrid(*((x_lin,) * 3), indexing="ij")
    return tuple(grids)
100
+
101
+
102
+ # NOTE: jaxdecomp's (jd) pfft3d transposes the axes (X, Y, Z) --> (Y, Z, X), and pifft3d undoes it,
103
+ # so the Fourier-space variables (e.g. kx, ky, kz) must all be transposed to match.
104
+
105
+
106
def xmeshgrid_transpose(x_lin):
    """Return the 'ij' meshgrid of x_lin with every grid's axes permuted by (1, 2, 0).

    For each returned grid, out[a, b, c] == grid[c, a, b]. This gives the
    (Y, Z, X) axis ordering used for Fourier-space arrays with jaxdecomp.
    """
    return tuple(
        jnp.transpose(grid, (1, 2, 0))
        for grid in jnp.meshgrid(x_lin, x_lin, x_lin, indexing="ij")
    )
112
+
113
+
114
def xzeros(nx):
    """Return an (nx, nx, nx) array of zeros in JAX's default float dtype."""
    shape = (nx,) * 3
    return jnp.zeros(shape)
116
+
117
+
118
def xones(nx):
    """Return an (nx, nx, nx) array of ones in JAX's default float dtype."""
    shape = (nx,) * 3
    return jnp.ones(shape)
jaxion/visualization.py CHANGED
@@ -1,3 +1,4 @@
1
+ import jax
1
2
  import jax.numpy as jnp
2
3
  import matplotlib.pyplot as plt
3
4
  import os
@@ -8,67 +9,82 @@ def plot_sim(state, checkpoint_dir, i, params):
8
9
 
9
10
  dynamic_range = params["output"]["plot_dynamic_range"]
10
11
 
12
+ # process distributed data
11
13
  if params["physics"]["quantum"]:
12
- plt.clf()
13
-
14
- # DM projection
15
14
  nx = state["psi"].shape[0]
16
- rho_bar = jnp.mean(jnp.abs(state["psi"]) ** 2)
17
- vmin = jnp.log10(rho_bar / dynamic_range)
18
- vmax = jnp.log10(rho_bar * dynamic_range)
15
+ rho_bar_dm = jnp.mean(jnp.abs(state["psi"]) ** 2)
16
+ rho_proj_dm = jax.experimental.multihost_utils.process_allgather(
17
+ jnp.log10(jnp.mean(jnp.abs(state["psi"]) ** 2, axis=2).T)
18
+ ).reshape(nx, nx)
19
+ if params["physics"]["hydro"]:
20
+ nx = state["rho"].shape[0]
21
+ rho_bar_gas = jnp.mean(state["rho"])
22
+ rho_proj_gas = jax.experimental.multihost_utils.process_allgather(
23
+ jnp.log10(jnp.mean(jnp.abs(state["rho"]) ** 2, axis=2).T)
24
+ ).reshape(nx, nx)
19
25
 
20
- rho_proj_dm = jnp.log10(jnp.mean(jnp.abs(state["psi"]) ** 2, axis=2)).T
21
- ax = plt.gca()
22
- ax.imshow(
23
- rho_proj_dm,
24
- cmap="inferno",
25
- origin="lower",
26
- vmin=vmin,
27
- vmax=vmax,
28
- extent=[0, nx, 0, nx],
29
- )
30
- if params["physics"]["particles"]:
31
- # draw particles
32
- box_size = params["domain"]["box_size"]
33
- sx = (state["pos"][:, 0] / box_size) * nx
34
- sy = (state["pos"][:, 1] / box_size) * nx
35
- plt.plot(sx, sy, color="cyan", marker=".", linestyle="None", markersize=2)
36
- ax.set_aspect("equal")
37
- ax.get_xaxis().set_visible(False)
38
- ax.get_yaxis().set_visible(False)
26
+ # create plot on process 0
27
+ if jax.process_index() == 0:
28
+ if params["physics"]["quantum"]:
29
+ plt.clf()
39
30
 
40
- plt.savefig(
41
- os.path.join(checkpoint_dir, f"dm{i:03d}.png"),
42
- bbox_inches="tight",
43
- pad_inches=0,
44
- )
45
- plt.close()
31
+ # DM projection
32
+ nx = state["psi"].shape[0]
33
+ vmin = jnp.log10(rho_bar_dm / dynamic_range)
34
+ vmax = jnp.log10(rho_bar_dm * dynamic_range)
46
35
 
47
- if params["physics"]["hydro"]:
48
- plt.clf()
36
+ ax = plt.gca()
37
+ ax.imshow(
38
+ rho_proj_dm,
39
+ cmap="inferno",
40
+ origin="lower",
41
+ vmin=vmin,
42
+ vmax=vmax,
43
+ extent=[0, nx, 0, nx],
44
+ )
45
+ if params["physics"]["particles"]:
46
+ # draw particles
47
+ box_size = params["domain"]["box_size"]
48
+ sx = (state["pos"][:, 0] / box_size) * nx
49
+ sy = (state["pos"][:, 1] / box_size) * nx
50
+ plt.plot(
51
+ sx, sy, color="cyan", marker=".", linestyle="None", markersize=2
52
+ )
53
+ ax.set_aspect("equal")
54
+ ax.get_xaxis().set_visible(False)
55
+ ax.get_yaxis().set_visible(False)
49
56
 
50
- # gas projection
51
- nx = state["rho"].shape[0]
52
- rho_bar = jnp.mean(state["rho"])
53
- vmin = jnp.log10(rho_bar / dynamic_range)
54
- vmax = jnp.log10(rho_bar * dynamic_range)
55
- rho_proj_gas = jnp.log10(jnp.mean(state["rho"], axis=2)).T
56
- ax = plt.gca()
57
- ax.imshow(
58
- rho_proj_gas,
59
- cmap="viridis",
60
- origin="lower",
61
- vmin=vmin,
62
- vmax=vmax,
63
- extent=[0, nx, 0, nx],
64
- )
65
- ax.set_aspect("equal")
66
- ax.get_xaxis().set_visible(False)
67
- ax.get_yaxis().set_visible(False)
57
+ plt.savefig(
58
+ os.path.join(checkpoint_dir, f"dm{i:03d}.png"),
59
+ bbox_inches="tight",
60
+ pad_inches=0,
61
+ )
62
+ plt.close()
63
+
64
+ if params["physics"]["hydro"]:
65
+ plt.clf()
66
+
67
+ # gas projection
68
+ nx = state["rho"].shape[0]
69
+ vmin = jnp.log10(rho_bar_gas / dynamic_range)
70
+ vmax = jnp.log10(rho_bar_gas * dynamic_range)
71
+
72
+ ax = plt.gca()
73
+ ax.imshow(
74
+ rho_proj_gas,
75
+ cmap="viridis",
76
+ origin="lower",
77
+ vmin=vmin,
78
+ vmax=vmax,
79
+ extent=[0, nx, 0, nx],
80
+ )
81
+ ax.set_aspect("equal")
82
+ ax.get_xaxis().set_visible(False)
83
+ ax.get_yaxis().set_visible(False)
68
84
 
69
- plt.savefig(
70
- os.path.join(checkpoint_dir, f"gas{i:03d}.png"),
71
- bbox_inches="tight",
72
- pad_inches=0,
73
- )
74
- plt.close()
85
+ plt.savefig(
86
+ os.path.join(checkpoint_dir, f"gas{i:03d}.png"),
87
+ bbox_inches="tight",
88
+ pad_inches=0,
89
+ )
90
+ plt.close()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: jaxion
3
- Version: 0.0.2
3
+ Version: 0.0.3
4
4
  Summary: A differentiable simulation library for fuzzy dark matter in JAX
5
5
  Author-email: Philip Mocz <philip.mocz@gmail.com>
6
6
  License-Expression: Apache-2.0
@@ -21,6 +21,12 @@ Provides-Extra: cuda12
21
21
  Requires-Dist: jax[cuda12]==0.5.3; extra == "cuda12"
22
22
  Dynamic: license-file
23
23
 
24
+ <p align="center">
25
+ <a href="https://jaxion.readthedocs.io">
26
+ <img src="docs/_static/jaxion-logo.svg" alt="jaxion logo" width="128"/>
27
+ </a>
28
+ </p>
29
+
24
30
  # jaxion
25
31
 
26
32
  [![Repo Status][status-badge]][status-link]
@@ -44,9 +50,9 @@ A simple JAX-powered simulation library for numerical experiments of fuzzy dark
44
50
 
45
51
  Author: [Philip Mocz (@pmocz)](https://github.com/pmocz/)
46
52
 
47
- ⚠️ Jaxion is currently being developed and is not yet ready for use. Check back later ⚠️
53
+ Jaxion is built for multi-GPU scalability and is fully differentiable. It is a high-performance JAX-based simulation library for modeling fuzzy dark matter alongside stars, gas, and cosmological dynamics. Being differentiable, Jaxion can seamlessly integrate with pipelines for inverse problems, inference, optimization, and coupling to ML models.
48
54
 
49
- Jaxion is built for multi-GPU scalability and is fully differentiable. It is a high-performance JAX-based simulation library for modeling fuzzy dark matter alongside stars, gas, and cosmological dynamics. Being differentiable, Jaxion can seemlessly integrate with piplines for inverse-problems, inference, optimization, and coupling to ML models.
55
+ Jaxion is the simpler companion project to the differentiable astrophysics code [Adirondax](https://github.com/AdirondaxProject/adirondax).
50
56
 
51
57
 
52
58
  ## Getting started
@@ -66,6 +72,20 @@ pip install jaxion[cuda12]
66
72
  Check out the `examples/` directory for demonstrations of using Jaxion.
67
73
 
68
74
 
75
+ ## Examples
76
+
77
+ <p align="center">
78
+ <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/dynamical_friction">
79
+ <img src="examples/dynamical_friction/movie.gif" alt="dynamical_friction" width="128"/>
80
+ </a>
81
+ <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/heating_gas">
82
+ <img src="examples/heating_gas/movie.gif" alt="heating_gas" width="128"/>
83
+ </a>
84
+ <a href="https://github.com/JaxionProject/jaxion/tree/main/examples/tidal_stripping">
85
+ <img src="examples/tidal_stripping/movie.gif" alt="tidal_stripping" width="128"/>
86
+ </a>
87
+ </p>
88
+
69
89
  ## Links
70
90
 
71
91
  * [Code repository](https://github.com/JaxionProject/jaxion) on GitHub (this page).
@@ -0,0 +1,17 @@
1
+ jaxion/__init__.py,sha256=Hdji1UQ47lG24Pqcy6UUq9L0-qy6m9Ax41L0vIYzBto,164
2
+ jaxion/analysis.py,sha256=UZpCvDr2ISDAgRoZFtuhvDk0QtN3mn5NxIoYYY8wzs8,1304
3
+ jaxion/constants.py,sha256=b0vgqtezIEGp4O5JRJo3YK3eQC8AN6hmV84-4pYWgmc,528
4
+ jaxion/cosmology.py,sha256=UC1McXNTXGoPRYXn0nI2-csVkJWL-ZBNoCa44oU1b4w,2681
5
+ jaxion/gravity.py,sha256=3brRZelKm-soXqk_Lt3SqhbZ00woJCraqwdMuR-KooA,291
6
+ jaxion/hydro.py,sha256=KoJ02tRpAc4V3Ofzw4zbHLRaE2GdIatbOBE04_LsSRw,6980
7
+ jaxion/params_default.json,sha256=9CJrhEPsv5zGEs7_WqFyuccCDipPCDhXgKzVdqOsOWE,2775
8
+ jaxion/particles.py,sha256=pMopGvoZ0J_3EviD0WnTMmiebU9h2_8IO-p6I-E5DEU,3980
9
+ jaxion/quantum.py,sha256=GWOpN6ipfEw-6Ah2zQpxS3oqeSt_iHMDSgnVYSjXY5E,3321
10
+ jaxion/simulation.py,sha256=3UrrHPJr0_39BiP86nQSl4ui8EsPHf9RrZ5hDobit5k,14884
11
+ jaxion/utils.py,sha256=rT7NM0FNEgFwN7oTgTb-jkR66Iw0xYTHHxcoikYd1ag,3572
12
+ jaxion/visualization.py,sha256=MJ3iqrvAem4RbozM79Ri2cNTIwKUV8o5X8M32FiDfbo,2928
13
+ jaxion-0.0.3.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
14
+ jaxion-0.0.3.dist-info/METADATA,sha256=1C3fwJHJEBZYMEhxmTkKd7z_ZUrF15E3P0oS5XxrdjA,3579
15
+ jaxion-0.0.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
16
+ jaxion-0.0.3.dist-info/top_level.txt,sha256=S1OV2VdlDG_9UwpKOIji4itQGOS-VWUOWUi3GeXWzt0,7
17
+ jaxion-0.0.3.dist-info/RECORD,,
@@ -1,15 +0,0 @@
1
- jaxion/__init__.py,sha256=gHrU_SvIBIbuYt-MeEeyiMyuJKwLjTTiz513OZ7CDHQ,95
2
- jaxion/constants.py,sha256=b0vgqtezIEGp4O5JRJo3YK3eQC8AN6hmV84-4pYWgmc,528
3
- jaxion/gravity.py,sha256=DdVOhAaoIPLXtW7ALrMK2umLPkStzxJ778l1SD2yMkY,266
4
- jaxion/hydro.py,sha256=3ZFNDhIaPBw4uTBqFbjJotCLkuKG8AvRcY1T78QQ96U,6953
5
- jaxion/params_default.json,sha256=4-XNPN52phGxYsP6FOeIJH0ALtICjrQxApotz-NhJEc,1782
6
- jaxion/particles.py,sha256=7oJBVyddfOXUDYxQuGIs8_ByL2Hu6jkuWE1MWtshFjk,3953
7
- jaxion/quantum.py,sha256=91tEotQhX1CB3HCaRCDARRPX_dEx58GIUCluY5RyPvA,377
8
- jaxion/simulation.py,sha256=KcIi91LHJO5K4M0gxU4qc7e4JEeOyW1HUZOYwt8FOG0,10385
9
- jaxion/utils.py,sha256=LCVx5_N9-6HsbENENOEz8tI6Czl0Z8Bit2icf1k4024,2880
10
- jaxion/visualization.py,sha256=WKID2_LFbbqXwmW_TLTnVznKYHfjiDKNbjD9xS4GfLc,2209
11
- jaxion-0.0.2.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
12
- jaxion-0.0.2.dist-info/METADATA,sha256=uUe69xVa678F2uV28-CczduNdfoOtBpwhfuM_mJKXpM,2811
13
- jaxion-0.0.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
14
- jaxion-0.0.2.dist-info/top_level.txt,sha256=S1OV2VdlDG_9UwpKOIji4itQGOS-VWUOWUi3GeXWzt0,7
15
- jaxion-0.0.2.dist-info/RECORD,,
File without changes