classy-szfast 0.0.25.post10__tar.gz → 0.0.25.post12__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27)
  1. {classy_szfast-0.0.25.post10 → classy_szfast-0.0.25.post12}/PKG-INFO +4 -4
  2. {classy_szfast-0.0.25.post10 → classy_szfast-0.0.25.post12}/classy_szfast/classy_szfast.py +4 -4
  3. classy_szfast-0.0.25.post12/classy_szfast/utils.py +222 -0
  4. {classy_szfast-0.0.25.post10 → classy_szfast-0.0.25.post12}/classy_szfast.egg-info/PKG-INFO +4 -4
  5. classy_szfast-0.0.25.post12/classy_szfast.egg-info/requires.txt +7 -0
  6. {classy_szfast-0.0.25.post10 → classy_szfast-0.0.25.post12}/classy_szfast.egg-info/top_level.txt +1 -0
  7. {classy_szfast-0.0.25.post10 → classy_szfast-0.0.25.post12}/pyproject.toml +9 -5
  8. classy_szfast-0.0.25.post10/classy_szfast/utils.py +0 -73
  9. classy_szfast-0.0.25.post10/classy_szfast.egg-info/requires.txt +0 -7
  10. {classy_szfast-0.0.25.post10 → classy_szfast-0.0.25.post12}/README.md +0 -0
  11. {classy_szfast-0.0.25.post10 → classy_szfast-0.0.25.post12}/classy_szfast/__init__.py +0 -0
  12. {classy_szfast-0.0.25.post10 → classy_szfast-0.0.25.post12}/classy_szfast/classy_sz.py +0 -0
  13. {classy_szfast-0.0.25.post10 → classy_szfast-0.0.25.post12}/classy_szfast/config.py +0 -0
  14. {classy_szfast-0.0.25.post10 → classy_szfast-0.0.25.post12}/classy_szfast/cosmopower.py +0 -0
  15. {classy_szfast-0.0.25.post10 → classy_szfast-0.0.25.post12}/classy_szfast/cosmopower_jax.py +0 -0
  16. {classy_szfast-0.0.25.post10 → classy_szfast-0.0.25.post12}/classy_szfast/cosmosis_classy_szfast_interface.py +0 -0
  17. {classy_szfast-0.0.25.post10 → classy_szfast-0.0.25.post12}/classy_szfast/custom_bias/__init__.py +0 -0
  18. {classy_szfast-0.0.25.post10 → classy_szfast-0.0.25.post12}/classy_szfast/custom_bias/custom_bias.py +0 -0
  19. {classy_szfast-0.0.25.post10 → classy_szfast-0.0.25.post12}/classy_szfast/custom_profiles/__init__.py +0 -0
  20. {classy_szfast-0.0.25.post10 → classy_szfast-0.0.25.post12}/classy_szfast/custom_profiles/custom_profiles.py +0 -0
  21. {classy_szfast-0.0.25.post10 → classy_szfast-0.0.25.post12}/classy_szfast/emulators_meta_data.py +0 -0
  22. {classy_szfast-0.0.25.post10 → classy_szfast-0.0.25.post12}/classy_szfast/pks_and_sigmas.py +0 -0
  23. {classy_szfast-0.0.25.post10 → classy_szfast-0.0.25.post12}/classy_szfast/restore_nn.py +0 -0
  24. {classy_szfast-0.0.25.post10 → classy_szfast-0.0.25.post12}/classy_szfast/suppress_warnings.py +0 -0
  25. {classy_szfast-0.0.25.post10 → classy_szfast-0.0.25.post12}/classy_szfast.egg-info/SOURCES.txt +0 -0
  26. {classy_szfast-0.0.25.post10 → classy_szfast-0.0.25.post12}/classy_szfast.egg-info/dependency_links.txt +0 -0
  27. {classy_szfast-0.0.25.post10 → classy_szfast-0.0.25.post12}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: classy_szfast
3
- Version: 0.0.25.post10
3
+ Version: 0.0.25.post12
4
4
  Summary: The accelerator of the class_sz code from https://github.com/CLASS-SZ
5
5
  Maintainer-email: Boris Bolliet <bb667@cam.ac.uk>
6
6
  License: MIT
@@ -10,7 +10,7 @@ Description-Content-Type: text/markdown
10
10
  Requires-Dist: numpy>=1.19.0
11
11
  Requires-Dist: Cython>=0.29.21
12
12
  Requires-Dist: tensorflow
13
- Requires-Dist: mcfit
14
- Requires-Dist: get_cosmopower_emus==0.0.15
15
- Requires-Dist: class_sz_data
13
+ Requires-Dist: mcfit>=0.0.22
14
+ Requires-Dist: get_cosmopower_emus>=0.0.15
15
+ Requires-Dist: class_sz_data>=0.0.13
16
16
  Requires-Dist: cosmopower-jax
@@ -1,5 +1,5 @@
1
1
  from .utils import *
2
- from .utils import Const
2
+ from .utils import Const, jax_gradient
3
3
  from .config import *
4
4
  import numpy as np
5
5
  from .emulators_meta_data import emulator_dict, dofftlog_alphas, cp_l_max_scalars, cosmo_model_list
@@ -133,7 +133,7 @@ class Class_szfast(object):
133
133
  self.geomspace = jnp.geomspace
134
134
  self.arange = jnp.arange
135
135
  self.zeros = jnp.zeros
136
- self.gradient = jnp.gradient
136
+ self.gradient = jax_gradient
137
137
 
138
138
 
139
139
  else:
@@ -591,8 +591,8 @@ class Class_szfast(object):
591
591
 
592
592
  R, var[:,iz] = TophatVar(k, lowring=True)(P[:,iz], extrap=True)
593
593
 
594
- # dvar[:,iz] = self.gradient(var[:,iz], R) ## old form
595
- _, dvar[:,iz] = TophatVar(k, lowring=True, deriv=1)(P[:,iz]*k, extrap=True) # new form
594
+ dvar[:,iz] = self.gradient(var[:,iz], R) ## old form with jax gradient from inigo zubeldia if needed.
595
+ # _, dvar[:,iz] = TophatVar(k, lowring=True, deriv=1)(P[:,iz]*k, extrap=True) # new form doesnt seem to work accurately
596
596
 
597
597
  # dvar = dvar/(2.*np.sqrt(var))
598
598
  # print(dvar_grads/dvar)
@@ -0,0 +1,222 @@
1
+ import numpy as np
2
+ from datetime import datetime
3
+ import multiprocessing
4
+ import time
5
+ import functools
6
+ import re
7
+ from pkg_resources import resource_filename
8
+ import os
9
+ from scipy import optimize
10
+ from scipy.integrate import quad
11
+ from scipy.interpolate import interp1d
12
+ import math
13
+ from numpy import linalg as LA
14
+ import mcfit
15
+ from mcfit import P2xi
16
+ import jax
17
+ import jax.numpy as jnp
18
+ # import cosmopower
19
+ # import classy_sz as csz
20
+
21
+
22
+
23
+ from scipy.interpolate import LinearNDInterpolator
24
+ from scipy.interpolate import CloughTocher2DInterpolator
25
+
26
# --- Fundamental constants (SI) ---------------------------------------------
kb = 1.38064852e-23  # Boltzmann constant [m^2 kg s^-2 K^-1]
clight = 299792458.  # speed of light [m/s]
hplanck = 6.62607004e-34  # Planck constant [m^2 kg / s]
G_newton = 6.674e-11  # Newtonian constant of gravitation [m^3 kg^-1 s^-2]

# --- CMB temperature ----------------------------------------------------------
firas_T0 = 2.728  # pivot temperature used in the Max Lkl Analysis [K]
firas_T0_bf = 2.725  # best-fitting temperature [K]
Tcmb_uk = 2.7255e6  # CMB temperature 2.7255 K expressed in micro-Kelvin

# Critical density divided by h^2 [GeV / cm^3].
rho_crit_over_h2_in_GeV_per_cm3 = 1.0537e-5

# --- 21-cm line -----------------------------------------------------------------
# Frequency of a 21.1 cm wavelength: c [m/s] * 1e2 -> cm/s, divided by 21.1 cm
# gives Hz, then / 1e9 gives GHz (~1.42 GHz).
nu_21_cm_in_GHz = 1./21.1*clight*1.e2/1.e9
# Dimensionless frequency x = h nu / (k_B T) at T = firas_T0_bf
# (the trailing 1e9 converts GHz back to Hz).
x_21_cm = hplanck*nu_21_cm_in_GHz/kb/firas_T0_bf*1.e9

# --- Spectral-distortion constants ----------------------------------------------
kappa_c = 2.1419  # 4M_2-3M_c see below eq. 9b of https://arxiv.org/pdf/1506.06582.pdf
beta_mu = 2.1923

# Bose-Einstein integrals G_k = int_0^inf x^k / (e^x - 1) dx.
G1 = np.pi**2./6  # k = 1: zeta(2) = pi^2 / 6
G2 = 2.4041  # k = 2: 2 zeta(3), truncated to 4 decimals
G3 = np.pi**4/15.  # k = 3: 6 zeta(4) = pi^4 / 15
a_rho = G2/G3
alpha_mu = 2.*G1/3./G2  # = 1/beta_mu = π^2/18ζ(3) see eq. 4.15 CUSO lectures.

# --- Redshift boundaries --------------------------------------------------------
z_mu_era = 3e5
z_y_era = 5e4
z_reio_min = 6
z_reio_max = 25
z_recombination_min = 800
z_recombination_max = 1500
57
+
58
# -----------------------------------------------------------------------------
# Physical constants and unit-conversion factors
# -----------------------------------------------------------------------------
class Const:
    """Namespace of physical constants, stored as class-level attributes.

    Two naming families coexist. Descriptive names carry their unit in a
    suffix (``c_km_s``, ``h_J_s``, ``kB_J_K``). Underscore-delimited names
    (``_c_``, ``_Mpc_over_m_``, ...) follow the naming of the constants in
    the CLASS C code. A few quantities appear in both families with slightly
    different numerical values; both are kept as-is.
    """

    # Descriptive names, unit in the suffix.
    c_km_s = 299792.458  # speed of light [km/s]
    h_J_s = 6.626070040e-34  # Planck constant [J s]
    kB_J_K = 1.38064852e-23  # Boltzmann constant [J/K]

    # CLASS-style names.
    _c_ = 2.99792458e8  # speed of light [m/s]
    _Mpc_over_m_ = 3.085677581282e22  # number of metres in one megaparsec
    _Gyr_over_Mpc_ = 3.06601394e2  # light-travel distance of 1 Gyr in Mpc (c = 1)
    _G_ = 6.67428e-11  # Newtonian constant [m^3 kg^-1 s^-2]
    _eV_ = 1.602176487e-19  # 1 eV expressed in joules

    # Inputs to the Stefan-Boltzmann constant sigma_B (note: these differ
    # slightly from kB_J_K / h_J_s above).
    _k_B_ = 1.3806504e-23  # Boltzmann constant [J/K]
    _h_P_ = 6.62606896e-34  # Planck constant [J s]
    _M_sun_ = 1.98855e30  # solar mass [kg]
76
+
77
+
78
+
79
def jax_gradient(f, *varargs, axis=None, edge_order=1):
    """Return the numerical gradient of an N-dimensional array using JAX.

    JAX port of ``numpy.gradient``. It uses second-order accurate central
    differences in the interior and first- or second-order accurate
    one-sided differences at the boundaries.

    Parameters
    ----------
    f : array_like
        Samples of the function to differentiate.
    *varargs : scalar or 1-D array_like, optional
        Sample spacing. Allowed forms:
        - nothing: unit spacing on every axis;
        - a single scalar: that spacing on every axis;
        - one entry per differentiated axis: each is either a scalar spacing
          or a 1-D array of coordinates whose length matches that axis.
    axis : None, int or tuple/list of ints, optional
        Axis or axes along which to differentiate. ``None`` (default) means
        every axis. Negative values count from the last axis.
    edge_order : int, optional
        Accuracy order of the boundary differences. 1 (default) selects
        first-order; any other value up to 2 selects second-order.

    Returns
    -------
    jax.Array or tuple of jax.Array
        A single array when exactly one axis is differentiated. Otherwise a
        tuple with one array per axis. Each array has the shape of ``f``.

    Raises
    ------
    ValueError
        An axis is out of range or repeated; a coordinate array is not 1-D or
        has the wrong length; ``edge_order > 2``; or an axis has fewer than
        ``edge_order + 1`` samples.
    TypeError
        The number of spacing arguments matches neither 0, 1 nor the number
        of axes.

    Notes
    -----
    The uniform-spacing check uses a Python ``if`` on array values, so
    coordinate arrays must be concrete. NOTE(review): calling this under
    ``jax.jit`` with traced coordinate arrays will presumably fail. Confirm
    before jitting callers.
    """
    import operator  # local import keeps this helper self-contained

    f = jnp.asarray(f)
    N = f.ndim  # number of dimensions

    # Normalise ``axis`` to a tuple of distinct, non-negative indices. The
    # code below needs a tuple, because it takes len(axes) and iterates it.
    if axis is None:
        axes = tuple(range(N))
    else:
        requested = tuple(axis) if isinstance(axis, (tuple, list)) else (axis,)
        normalized = []
        for ax in requested:
            ax = operator.index(ax)
            if not -N <= ax < N:
                raise ValueError(
                    f"axis {ax} is out of bounds for array of dimension {N}"
                )
            normalized.append(ax % N)
        if len(set(normalized)) != len(normalized):
            raise ValueError("repeated axis")
        axes = tuple(normalized)
    len_axes = len(axes)

    n = len(varargs)
    if n == 0:
        # no spacing argument - use 1 in all axes
        dx = [1.0] * len_axes
    elif n == 1 and jnp.ndim(varargs[0]) == 0:
        # single scalar for all axes
        dx = list(varargs) * len_axes
    elif n == len_axes:
        # scalar or 1d array for each axis
        dx = list(varargs)
        for i, distances in enumerate(dx):
            distances = jnp.asarray(distances)
            if distances.ndim == 0:
                continue
            elif distances.ndim != 1:
                raise ValueError("distances must be either scalars or 1D")
            if len(distances) != f.shape[axes[i]]:
                raise ValueError("when 1D, distances must match "
                                 "the length of the corresponding dimension")
            if jnp.issubdtype(distances.dtype, jnp.integer):
                # Take differences in floating point, to avoid wrap-around
                # on unsigned integers. ``float`` maps to JAX's default float
                # dtype (float64 only when x64 is enabled), with no warning.
                distances = distances.astype(float)
            diffx = jnp.diff(distances)
            # If the spacing is constant, reduce to the scalar case, since
            # that is faster. The size guard avoids an IndexError on length-1
            # axes; those are rejected with a ValueError further down.
            if diffx.shape[0] > 0 and bool((diffx == diffx[0]).all()):
                diffx = diffx[0]
            dx[i] = diffx
    else:
        raise TypeError("invalid number of arguments")

    if edge_order > 2:
        raise ValueError("'edge_order' greater than 2 not supported")

    outvals = []
    # Index templates. Only the entry for the current axis is changed; the
    # others stay ":".
    slice1 = [slice(None)] * N
    slice2 = [slice(None)] * N
    slice3 = [slice(None)] * N
    slice4 = [slice(None)] * N

    otype = f.dtype
    # Integer input is differentiated in floating point.
    if jnp.issubdtype(otype, jnp.integer):
        f = f.astype(float)
        otype = f.dtype

    for ax, ax_dx in zip(axes, dx):
        if f.shape[ax] < edge_order + 1:
            raise ValueError(
                "Shape of array too small to calculate a numerical gradient, "
                "at least (edge_order + 1) elements are required.")
        # JAX arrays are immutable: start from zeros and fill via .at[].set().
        out = jnp.zeros_like(f, dtype=otype)
        uniform_spacing = jnp.ndim(ax_dx) == 0

        # Numerical differentiation: 2nd order interior
        slice1[ax] = slice(1, -1)
        slice2[ax] = slice(None, -2)
        slice3[ax] = slice(1, -1)
        slice4[ax] = slice(2, None)
        if uniform_spacing:
            out = out.at[tuple(slice1)].set(
                (f[tuple(slice4)] - f[tuple(slice2)]) / (2.0 * ax_dx)
            )
        else:
            dx1 = ax_dx[:-1]
            dx2 = ax_dx[1:]
            a = -(dx2) / (dx1 * (dx1 + dx2))
            b = (dx2 - dx1) / (dx1 * dx2)
            c = dx1 / (dx2 * (dx1 + dx2))
            # Broadcast the per-sample weights along the current axis only.
            shape = [1] * N
            shape[ax] = -1
            a = a.reshape(tuple(shape))
            b = b.reshape(tuple(shape))
            c = c.reshape(tuple(shape))
            out = out.at[tuple(slice1)].set(
                a * f[tuple(slice2)] + b * f[tuple(slice3)] + c * f[tuple(slice4)]
            )

        # Numerical differentiation: 1st order edges
        if edge_order == 1:
            slice1[ax] = 0
            slice2[ax] = 1
            slice3[ax] = 0
            dx_0 = ax_dx if uniform_spacing else ax_dx[0]
            out = out.at[tuple(slice1)].set(
                (f[tuple(slice2)] - f[tuple(slice3)]) / dx_0
            )
            slice1[ax] = -1
            slice2[ax] = -1
            slice3[ax] = -2
            dx_n = ax_dx if uniform_spacing else ax_dx[-1]
            out = out.at[tuple(slice1)].set(
                (f[tuple(slice2)] - f[tuple(slice3)]) / dx_n
            )
        # Numerical differentiation: 2nd order edges
        else:
            slice1[ax] = 0
            slice2[ax] = 0
            slice3[ax] = 1
            slice4[ax] = 2
            if uniform_spacing:
                a = -1.5 / ax_dx
                b = 2. / ax_dx
                c = -0.5 / ax_dx
            else:
                dx1 = ax_dx[0]
                dx2 = ax_dx[1]
                a = -(2. * dx1 + dx2) / (dx1 * (dx1 + dx2))
                b = (dx1 + dx2) / (dx1 * dx2)
                c = -dx1 / (dx2 * (dx1 + dx2))
            out = out.at[tuple(slice1)].set(
                a * f[tuple(slice2)] + b * f[tuple(slice3)] + c * f[tuple(slice4)]
            )
            slice1[ax] = -1
            slice2[ax] = -3
            slice3[ax] = -2
            slice4[ax] = -1
            if uniform_spacing:
                a = 0.5 / ax_dx
                b = -2. / ax_dx
                c = 1.5 / ax_dx
            else:
                dx1 = ax_dx[-2]
                dx2 = ax_dx[-1]
                a = dx2 / (dx1 * (dx1 + dx2))
                b = -(dx2 + dx1) / (dx1 * dx2)
                c = (2. * dx2 + dx1) / (dx2 * (dx1 + dx2))
            out = out.at[tuple(slice1)].set(
                a * f[tuple(slice2)] + b * f[tuple(slice3)] + c * f[tuple(slice4)]
            )

        outvals.append(out)

        # reset the slice object in this dimension to ":"
        slice1[ax] = slice(None)
        slice2[ax] = slice(None)
        slice3[ax] = slice(None)
        slice4[ax] = slice(None)

    if len_axes == 1:
        return outvals[0]
    # One gradient array per requested axis, in the order of ``axes``.
    return tuple(outvals)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: classy_szfast
3
- Version: 0.0.25.post10
3
+ Version: 0.0.25.post12
4
4
  Summary: The accelerator of the class_sz code from https://github.com/CLASS-SZ
5
5
  Maintainer-email: Boris Bolliet <bb667@cam.ac.uk>
6
6
  License: MIT
@@ -10,7 +10,7 @@ Description-Content-Type: text/markdown
10
10
  Requires-Dist: numpy>=1.19.0
11
11
  Requires-Dist: Cython>=0.29.21
12
12
  Requires-Dist: tensorflow
13
- Requires-Dist: mcfit
14
- Requires-Dist: get_cosmopower_emus==0.0.15
15
- Requires-Dist: class_sz_data
13
+ Requires-Dist: mcfit>=0.0.22
14
+ Requires-Dist: get_cosmopower_emus>=0.0.15
15
+ Requires-Dist: class_sz_data>=0.0.13
16
16
  Requires-Dist: cosmopower-jax
@@ -0,0 +1,7 @@
1
+ numpy>=1.19.0
2
+ Cython>=0.29.21
3
+ tensorflow
4
+ mcfit>=0.0.22
5
+ get_cosmopower_emus>=0.0.15
6
+ class_sz_data>=0.0.13
7
+ cosmopower-jax
@@ -3,7 +3,7 @@ requires = ["setuptools", "wheel"]
3
3
  build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
- version = "0.0.25.post10"
6
+ version = "0.0.25.post12"
7
7
  license = { text = "MIT" }
8
8
  name = "classy_szfast"
9
9
  maintainers = [{name = "Boris Bolliet",email="bb667@cam.ac.uk"}]
@@ -13,12 +13,16 @@ dependencies = [
13
13
  "numpy>=1.19.0",
14
14
  "Cython>=0.29.21",
15
15
  "tensorflow",
16
- "mcfit",
17
- "get_cosmopower_emus==0.0.15",
18
- "class_sz_data",
16
+ "mcfit>=0.0.22",
17
+ "get_cosmopower_emus>=0.0.15",
18
+ "class_sz_data>=0.0.13",
19
19
  "cosmopower-jax"
20
20
  ]
21
21
 
22
22
  [project.urls]
23
23
  Homepage = "https://github.com/CLASS-SZ"
24
- GitHub = "https://github.com/CLASS-SZ"
24
+ GitHub = "https://github.com/CLASS-SZ"
25
+
26
+
27
+ [tool.setuptools.packages.find]
28
+ where = ["."]
@@ -1,73 +0,0 @@
1
- import numpy as np
2
- from datetime import datetime
3
- import multiprocessing
4
- import time
5
- import functools
6
- import re
7
- from pkg_resources import resource_filename
8
- import os
9
- from scipy import optimize
10
- from scipy.integrate import quad
11
- from scipy.interpolate import interp1d
12
- import math
13
- from numpy import linalg as LA
14
- import mcfit
15
- from mcfit import P2xi
16
- # import cosmopower
17
- # import classy_sz as csz
18
-
19
-
20
-
21
- from scipy.interpolate import LinearNDInterpolator
22
- from scipy.interpolate import CloughTocher2DInterpolator
23
-
24
- kb = 1.38064852e-23 #m2 kg s-2 K-1
25
- clight = 299792458. #m/s
26
- hplanck=6.62607004e-34 #m2 kg / s
27
- firas_T0 = 2.728 #pivot temperature used in the Max Lkl Analysis
28
- firas_T0_bf = 2.725 #best-fitting temperature
29
-
30
- Tcmb_uk = 2.7255e6
31
-
32
- G_newton = 6.674e-11
33
- rho_crit_over_h2_in_GeV_per_cm3 = 1.0537e-5
34
-
35
-
36
- nu_21_cm_in_GHz = 1./21.1*clight*1.e2/1.e9
37
- x_21_cm = hplanck*nu_21_cm_in_GHz/kb/firas_T0_bf*1.e9
38
-
39
- kappa_c = 2.1419 # 4M_2-3M_c see below eq. 9b of https://arxiv.org/pdf/1506.06582.pdf
40
-
41
- beta_mu = 2.1923
42
-
43
- G1 = np.pi**2./6
44
- G2 = 2.4041
45
- G3 = np.pi**4/15.
46
- a_rho = G2/G3
47
- alpha_mu = 2.*G1/3./G2 # = 1/beta_mu = π^2/18ζ(3) see eq. 4.15 CUSO lectures.
48
-
49
- z_mu_era = 3e5
50
- z_y_era = 5e4
51
- z_reio_min = 6
52
- z_reio_max = 25
53
- z_recombination_min = 800
54
- z_recombination_max = 1500
55
-
56
- # Physical constants
57
- # ------------------
58
- # Light speed
59
- class Const:
60
- c_km_s = 299792.458 # speed of light
61
- h_J_s = 6.626070040e-34 # Planck's constant
62
- kB_J_K = 1.38064852e-23 # Boltzmann constant
63
-
64
- _c_ = 2.99792458e8 # c in m/s
65
- _Mpc_over_m_ = 3.085677581282e22 # conversion factor from meters to megaparsecs
66
- _Gyr_over_Mpc_ = 3.06601394e2 # conversion factor from megaparsecs to gigayears
67
- _G_ = 6.67428e-11 # Newton constant in m^3/Kg/s^2
68
- _eV_ = 1.602176487e-19 # 1 eV expressed in J
69
-
70
- # parameters entering in Stefan-Boltzmann constant sigma_B
71
- _k_B_ = 1.3806504e-23
72
- _h_P_ = 6.62606896e-34
73
- _M_sun_ = 1.98855e30 # solar mass in kg
@@ -1,7 +0,0 @@
1
- numpy>=1.19.0
2
- Cython>=0.29.21
3
- tensorflow
4
- mcfit
5
- get_cosmopower_emus==0.0.15
6
- class_sz_data
7
- cosmopower-jax