classy-szfast 0.0.25.post10__py3-none-any.whl → 0.0.25.post12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,5 +1,5 @@
1
1
  from .utils import *
2
- from .utils import Const
2
+ from .utils import Const, jax_gradient
3
3
  from .config import *
4
4
  import numpy as np
5
5
  from .emulators_meta_data import emulator_dict, dofftlog_alphas, cp_l_max_scalars, cosmo_model_list
@@ -133,7 +133,7 @@ class Class_szfast(object):
133
133
  self.geomspace = jnp.geomspace
134
134
  self.arange = jnp.arange
135
135
  self.zeros = jnp.zeros
136
- self.gradient = jnp.gradient
136
+ self.gradient = jax_gradient
137
137
 
138
138
 
139
139
  else:
@@ -591,8 +591,8 @@ class Class_szfast(object):
591
591
 
592
592
  R, var[:,iz] = TophatVar(k, lowring=True)(P[:,iz], extrap=True)
593
593
 
594
- # dvar[:,iz] = self.gradient(var[:,iz], R) ## old form
595
- _, dvar[:,iz] = TophatVar(k, lowring=True, deriv=1)(P[:,iz]*k, extrap=True) # new form
594
+ dvar[:,iz] = self.gradient(var[:,iz], R) ## old form with jax gradient from inigo zubeldia if needed.
595
+ # _, dvar[:,iz] = TophatVar(k, lowring=True, deriv=1)(P[:,iz]*k, extrap=True) # new form doesnt seem to work accurately
596
596
 
597
597
  # dvar = dvar/(2.*np.sqrt(var))
598
598
  # print(dvar_grads/dvar)
classy_szfast/utils.py CHANGED
@@ -13,6 +13,8 @@ import math
13
13
  from numpy import linalg as LA
14
14
  import mcfit
15
15
  from mcfit import P2xi
16
+ import jax
17
+ import jax.numpy as jnp
16
18
  # import cosmopower
17
19
  # import classy_sz as csz
18
20
 
@@ -71,3 +73,150 @@ class Const:
71
73
  _k_B_ = 1.3806504e-23
72
74
  _h_P_ = 6.62606896e-34
73
75
  _M_sun_ = 1.98855e30 # solar mass in kg
76
+
77
+
78
+
79
def jax_gradient(f, *varargs, axis=None, edge_order=1):
    """Return the gradient of an N-D array; a JAX port of ``numpy.gradient``.

    Interior points use second-order central differences (with the
    non-uniform-spacing weights when coordinates are given). Boundary points
    use first- or second-order one-sided differences, selected by
    ``edge_order``. The output is assembled with functional ``.at[].set``
    updates, so the result is a JAX array.

    Parameters
    ----------
    f : array_like
        Input samples.
    *varargs : scalar or 1-D array_like, optional
        Sample spacing. Three forms are accepted:
        - nothing: unit spacing on every axis;
        - one scalar: the same spacing for every axis;
        - one entry per differentiated axis: each entry is either a scalar
          spacing or a 1-D array of coordinates whose length matches
          ``f.shape`` along that axis.
    axis : None, int or tuple of ints, optional
        Axis or axes to differentiate along. ``None`` means all axes.
        Negative values count from the last axis.
    edge_order : int, optional
        ``1`` selects first-order one-sided differences at the boundaries.
        Any other value up to ``2`` selects second-order ones.

    Returns
    -------
    jax.Array or tuple of jax.Array
        A single array when exactly one axis is differentiated. Otherwise, a
        tuple with one gradient array per axis, in the order of ``axes``.

    Raises
    ------
    ValueError
        If ``axis`` is out of bounds or repeated, a spacing array is not
        scalar/1-D or has the wrong length, ``edge_order > 2``, or an axis
        has fewer than ``edge_order + 1`` samples.
    TypeError
        If the number of spacing arguments matches neither 0, 1 (scalar)
        nor the number of differentiated axes.

    Notes
    -----
    NOTE(review): the spacing checks below (``len(...)``, ``if ... .all()``)
    need concrete values. Spacing arrays therefore cannot be traced under
    ``jax.jit``. Confirm this is acceptable for callers.
    """
    f = jnp.asarray(f)
    N = f.ndim  # number of dimensions

    # Resolve which axes to differentiate along.
    if axis is None:
        axes = tuple(range(N))
    else:
        # Accept an int or a sequence of ints. Wrap negative values and
        # reject out-of-range or repeated axes.
        raw_axes = (axis,) if jnp.ndim(axis) == 0 else tuple(axis)
        normalized = []
        for ax in raw_axes:
            ax = int(ax)
            if not -N <= ax < N:
                raise ValueError(
                    f"axis {ax} is out of bounds for array of dimension {N}")
            normalized.append(ax % N)
        if len(set(normalized)) != len(normalized):
            raise ValueError("repeated axis in 'axis' argument")
        axes = tuple(normalized)
    len_axes = len(axes)

    # Resolve the spacing for each axis: either a scalar step or an array of
    # consecutive coordinate differences.
    n = len(varargs)
    if n == 0:
        # no spacing argument - use 1 in all axes
        dx = [1.0] * len_axes
    elif n == 1 and jnp.ndim(varargs[0]) == 0:
        # single scalar for all axes
        dx = varargs * len_axes
    elif n == len_axes:
        # scalar or 1d array for each axis
        dx = list(varargs)
        for i, distances in enumerate(dx):
            distances = jnp.asarray(distances)
            if distances.ndim == 0:
                continue
            elif distances.ndim != 1:
                raise ValueError("distances must be either scalars or 1D")
            if len(distances) != f.shape[axes[i]]:
                raise ValueError("when 1D, distances must match "
                                 "the length of the corresponding dimension")
            if jnp.issubdtype(distances.dtype, jnp.integer):
                # Convert integer coordinates to float so the differences
                # below are computed in floating point.
                distances = distances.astype(jnp.float64)
            diffx = jnp.diff(distances)
            # Constant steps reduce to the cheaper scalar (uniform) case.
            if (diffx == diffx[0]).all():
                diffx = diffx[0]
            dx[i] = diffx
    else:
        raise TypeError("invalid number of arguments")

    if edge_order > 2:
        raise ValueError("'edge_order' greater than 2 not supported")

    outvals = []
    # Reusable index templates. Each axis iteration sets the entry for the
    # current axis and resets it to ':' afterwards.
    slice1 = [slice(None)] * N
    slice2 = [slice(None)] * N
    slice3 = [slice(None)] * N
    slice4 = [slice(None)] * N

    otype = f.dtype
    # Integer input is promoted to floating point for the result.
    if jnp.issubdtype(otype, jnp.integer):
        f = f.astype(jnp.float64)
        otype = jnp.float64

    for ax, ax_dx in zip(axes, dx):
        if f.shape[ax] < edge_order + 1:
            raise ValueError(
                "Shape of array too small to calculate a numerical gradient, "
                "at least (edge_order + 1) elements are required.")
        # result allocation
        out = jnp.empty_like(f, dtype=otype)

        uniform_spacing = jnp.ndim(ax_dx) == 0

        # Numerical differentiation: 2nd order interior
        slice1[ax] = slice(1, -1)
        slice2[ax] = slice(None, -2)
        slice3[ax] = slice(1, -1)
        slice4[ax] = slice(2, None)

        if uniform_spacing:
            out = out.at[tuple(slice1)].set(
                (f[tuple(slice4)] - f[tuple(slice2)]) / (2.0 * ax_dx)
            )
        else:
            # Three-point weights for non-uniform steps dx1 (left), dx2 (right).
            dx1 = ax_dx[:-1]
            dx2 = ax_dx[1:]
            a = -(dx2) / (dx1 * (dx1 + dx2))
            b = (dx2 - dx1) / (dx1 * dx2)
            c = dx1 / (dx2 * (dx1 + dx2))
            # Broadcast the 1-D weights along the current axis only.
            shape = [1] * N
            shape[ax] = -1
            a = a.reshape(shape)
            b = b.reshape(shape)
            c = c.reshape(shape)
            out = out.at[tuple(slice1)].set(
                a * f[tuple(slice2)] + b * f[tuple(slice3)] + c * f[tuple(slice4)]
            )

        if edge_order == 1:
            # Numerical differentiation: 1st order edges
            slice1[ax] = 0
            slice2[ax] = 1
            slice3[ax] = 0
            dx_0 = ax_dx if uniform_spacing else ax_dx[0]
            out = out.at[tuple(slice1)].set(
                (f[tuple(slice2)] - f[tuple(slice3)]) / dx_0
            )

            slice1[ax] = -1
            slice2[ax] = -1
            slice3[ax] = -2
            dx_n = ax_dx if uniform_spacing else ax_dx[-1]
            out = out.at[tuple(slice1)].set(
                (f[tuple(slice2)] - f[tuple(slice3)]) / dx_n
            )
        else:
            # Numerical differentiation: 2nd order edges
            slice1[ax] = 0
            slice2[ax] = 0
            slice3[ax] = 1
            slice4[ax] = 2
            if uniform_spacing:
                a = -1.5 / ax_dx
                b = 2. / ax_dx
                c = -0.5 / ax_dx
            else:
                dx1 = ax_dx[0]
                dx2 = ax_dx[1]
                a = -(2. * dx1 + dx2) / (dx1 * (dx1 + dx2))
                b = (dx1 + dx2) / (dx1 * dx2)
                c = -dx1 / (dx2 * (dx1 + dx2))
            out = out.at[tuple(slice1)].set(
                a * f[tuple(slice2)] + b * f[tuple(slice3)] + c * f[tuple(slice4)]
            )

            slice1[ax] = -1
            slice2[ax] = -3
            slice3[ax] = -2
            slice4[ax] = -1
            if uniform_spacing:
                a = 0.5 / ax_dx
                b = -2. / ax_dx
                c = 1.5 / ax_dx
            else:
                dx1 = ax_dx[-2]
                dx2 = ax_dx[-1]
                a = dx2 / (dx1 * (dx1 + dx2))
                b = -(dx2 + dx1) / (dx1 * dx2)
                c = (2. * dx2 + dx1) / (dx2 * (dx1 + dx2))
            out = out.at[tuple(slice1)].set(
                a * f[tuple(slice2)] + b * f[tuple(slice3)] + c * f[tuple(slice4)]
            )

        outvals.append(out)

        # reset the slice object in this dimension to ":"
        slice1[ax] = slice(None)
        slice2[ax] = slice(None)
        slice3[ax] = slice(None)
        slice4[ax] = slice(None)

    if len_axes == 1:
        return outvals[0]
    # Several axes: one gradient array per axis, like numpy.gradient.
    return tuple(outvals)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: classy_szfast
3
- Version: 0.0.25.post10
3
+ Version: 0.0.25.post12
4
4
  Summary: The accelerator of the class_sz code from https://github.com/CLASS-SZ
5
5
  Maintainer-email: Boris Bolliet <bb667@cam.ac.uk>
6
6
  License: MIT
@@ -10,7 +10,7 @@ Description-Content-Type: text/markdown
10
10
  Requires-Dist: numpy>=1.19.0
11
11
  Requires-Dist: Cython>=0.29.21
12
12
  Requires-Dist: tensorflow
13
- Requires-Dist: mcfit
14
- Requires-Dist: get_cosmopower_emus==0.0.15
15
- Requires-Dist: class_sz_data
13
+ Requires-Dist: mcfit>=0.0.22
14
+ Requires-Dist: get_cosmopower_emus>=0.0.15
15
+ Requires-Dist: class_sz_data>=0.0.13
16
16
  Requires-Dist: cosmopower-jax
@@ -1,6 +1,6 @@
1
1
  classy_szfast/__init__.py,sha256=E2thrL0Z9oXFfdzwcsu-xbOytudLFTlRlPqVFGlPPPg,279
2
2
  classy_szfast/classy_sz.py,sha256=QmbwrSXInQLMvCDqsr7KPmtaU0KOiOt1Rb-cTKuulZw,22240
3
- classy_szfast/classy_szfast.py,sha256=ccKynjQks6nrPXa-eLREmCoIz_cl6e6dVJxdcdqavAE,42821
3
+ classy_szfast/classy_szfast.py,sha256=18p1wwBLFKTSXh2iCHoSSVD5NNwsFMyrOOrpbnsBcyk,42915
4
4
  classy_szfast/config.py,sha256=v6DGcBHmfn5JtuO48dKyXCh-Dmn0uwOF_izvVOJFnqw,279
5
5
  classy_szfast/cosmopower.py,sha256=ooYK2BDOZSo3XtGHfPtjXHxr5UW-yVngLPkb5gpvTx8,2351
6
6
  classy_szfast/cosmopower_jax.py,sha256=NqU_Sw5x0qw16DNHrN4ipCZP8HVdQvUiCq_Y04h8zoo,5468
@@ -9,12 +9,12 @@ classy_szfast/emulators_meta_data.py,sha256=mXG5LQuJw9QBNE_kxXW8Kx0AUCWpbV6uRO9B
9
9
  classy_szfast/pks_and_sigmas.py,sha256=drtuujE1HhlrYY1hY92DyY5lXlYS1uE15MSuVI4uo6k,6625
10
10
  classy_szfast/restore_nn.py,sha256=DqA9thhTRiGBDVb9zjhqcbF2W4V0AU0vrjJFhnLboU4,21075
11
11
  classy_szfast/suppress_warnings.py,sha256=6wIBml2Sj9DyRGZlZWhuA9hqvpxqrNyYjuz6BPK_a6E,202
12
- classy_szfast/utils.py,sha256=pDZSIgmII3J5bMUYAD8ErvngyeJEEhtla12kCIAWkMg,1962
12
+ classy_szfast/utils.py,sha256=hdA4_ZqCTuO6YODptpROVTns4jYxJqX6rDU9FHmJ1NA,7502
13
13
  classy_szfast/custom_bias/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
14
  classy_szfast/custom_bias/custom_bias.py,sha256=aR2t5RTIwv7P0m2bsEU0Eq6BTkj4pG10AebH6QpG4qM,486
15
15
  classy_szfast/custom_profiles/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
16
16
  classy_szfast/custom_profiles/custom_profiles.py,sha256=4LZwb2XoqwCyWNmW2s24Z7AJdmgVdaRG7yYaBYe-d9Q,1188
17
- classy_szfast-0.0.25.post10.dist-info/METADATA,sha256=RycC7LMaSonaPaonweilq5GM_AfGs4sNbcZrkZ9QMNM,556
18
- classy_szfast-0.0.25.post10.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
19
- classy_szfast-0.0.25.post10.dist-info/top_level.txt,sha256=hRgqpilUck4lx2KkaWI2y9aCDKqF6pFfGHfNaoPFxv0,14
20
- classy_szfast-0.0.25.post10.dist-info/RECORD,,
17
+ classy_szfast-0.0.25.post12.dist-info/METADATA,sha256=xvy4-nx-2tsgs6ol1CPjWoMyOsBli_PSBR0YSqC088U,572
18
+ classy_szfast-0.0.25.post12.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
19
+ classy_szfast-0.0.25.post12.dist-info/top_level.txt,sha256=hRgqpilUck4lx2KkaWI2y9aCDKqF6pFfGHfNaoPFxv0,14
20
+ classy_szfast-0.0.25.post12.dist-info/RECORD,,