classy-szfast 0.0.25.post26__py3-none-any.whl → 0.0.25.post27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- classy_szfast/classy_szfast.py +14 -4
- classy_szfast/cosmopower_jax.py +6 -2
- {classy_szfast-0.0.25.post26.dist-info → classy_szfast-0.0.25.post27.dist-info}/METADATA +2 -2
- {classy_szfast-0.0.25.post26.dist-info → classy_szfast-0.0.25.post27.dist-info}/RECORD +6 -6
- {classy_szfast-0.0.25.post26.dist-info → classy_szfast-0.0.25.post27.dist-info}/WHEEL +1 -1
- {classy_szfast-0.0.25.post26.dist-info → classy_szfast-0.0.25.post27.dist-info}/top_level.txt +0 -0
classy_szfast/classy_szfast.py
CHANGED
@@ -122,6 +122,7 @@ class Class_szfast(object):
|
|
122
122
|
self.cp_h_nn = cp_h_nn_jax
|
123
123
|
self.cp_da_nn = cp_da_nn_jax
|
124
124
|
self.cp_pkl_nn = cp_pkl_nn_jax
|
125
|
+
self.cp_pknl_nn = cp_pknl_nn_jax
|
125
126
|
self.cp_der_nn = cp_der_nn_jax
|
126
127
|
|
127
128
|
self.pi = jnp.pi
|
@@ -153,6 +154,7 @@ class Class_szfast(object):
|
|
153
154
|
self.cp_h_nn = cp_h_nn
|
154
155
|
self.cp_da_nn = cp_da_nn
|
155
156
|
self.cp_pkl_nn = cp_pkl_nn
|
157
|
+
self.cp_pknl_nn = cp_pknl_nn
|
156
158
|
self.cp_der_nn = cp_der_nn
|
157
159
|
self.pi = np.pi
|
158
160
|
self.transpose = np.transpose
|
@@ -779,7 +781,10 @@ class Class_szfast(object):
|
|
779
781
|
for zp in z_arr:
|
780
782
|
params_dict_pp = params_dict.copy()
|
781
783
|
params_dict_pp['z_pk_save_nonclass'] = [zp]
|
782
|
-
|
784
|
+
if self.jax_mode:
|
785
|
+
predicted_pk_spectrum_z.append(self.cp_pknl_nn[self.cosmo_model].predict(params_dict_pp))
|
786
|
+
else:
|
787
|
+
predicted_pk_spectrum_z.append(self.cp_pknl_nn[self.cosmo_model].predictions_np(params_dict_pp)[0])
|
783
788
|
|
784
789
|
predicted_pk_spectrum = self.asarray(predicted_pk_spectrum_z)
|
785
790
|
|
@@ -790,7 +795,10 @@ class Class_szfast(object):
|
|
790
795
|
pk_re = self.transpose(pk_re)
|
791
796
|
|
792
797
|
|
793
|
-
|
798
|
+
if self.jax_mode:
|
799
|
+
self.pknl_interp = None
|
800
|
+
else:
|
801
|
+
self.pknl_interp = PowerSpectrumInterpolator(z_arr,k_arr,self.log(pk_re).T,logP=True)
|
794
802
|
|
795
803
|
|
796
804
|
self.cszfast_pk_grid_pknl = pk_re
|
@@ -877,12 +885,14 @@ class Class_szfast(object):
|
|
877
885
|
|
878
886
|
predicted_pk_spectrum_z = []
|
879
887
|
|
880
|
-
z_asked = z_asked
|
881
888
|
params_dict_pp = params_dict.copy()
|
882
889
|
update_params_with_defaults(params_dict_pp, self.emulator_dict[self.cosmo_model]['default'])
|
883
890
|
|
884
891
|
params_dict_pp['z_pk_save_nonclass'] = [z_asked]
|
885
|
-
|
892
|
+
if self.jax_mode:
|
893
|
+
predicted_pk_spectrum_z.append(self.cp_pknl_nn[self.cosmo_model].predict(params_dict_pp))
|
894
|
+
else:
|
895
|
+
predicted_pk_spectrum_z.append(self.cp_pknl_nn[self.cosmo_model].predictions_np(params_dict_pp)[0])
|
886
896
|
|
887
897
|
predicted_pk_spectrum = self.asarray(predicted_pk_spectrum_z)
|
888
898
|
|
classy_szfast/cosmopower_jax.py
CHANGED
@@ -132,8 +132,12 @@ for mp in cosmo_model_list:
|
|
132
132
|
|
133
133
|
cp_pp_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'PP/' + emulator_dict[mp]['PP'])
|
134
134
|
|
135
|
-
cp_pknl_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'PK/' + emulator_dict[mp]['PKNL'])
|
136
|
-
|
135
|
+
# cp_pknl_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'PK/' + emulator_dict[mp]['PKNL'])
|
136
|
+
|
137
|
+
cp_pknl_nn_jax[mp] = CosmoPowerJAX_custom(probe='custom_log',filepath=path_to_emulators +'PK/' + emulator_dict[mp]['PKNL'] + '.npz')
|
138
|
+
cp_pknl_nn_jax[mp].ten_to_predictions = False
|
139
|
+
|
140
|
+
|
137
141
|
cp_pkl_nn_jax[mp] = CosmoPowerJAX_custom(probe='custom_log',filepath=path_to_emulators +'PK/' + emulator_dict[mp]['PKL'] + '.npz')
|
138
142
|
cp_pkl_nn_jax[mp].ten_to_predictions = False
|
139
143
|
|
@@ -1,9 +1,9 @@
|
|
1
1
|
classy_szfast/__init__.py,sha256=8XVvg_q0kf97P_glT4o-6Jnss4AG4RgHat4CEslWDKs,305
|
2
2
|
classy_szfast/classy_sz.py,sha256=SCzZYCUSR0c-DAGheX-OlgzzOxOc24d5T4b-ezdhubU,25512
|
3
|
-
classy_szfast/classy_szfast.py,sha256=
|
3
|
+
classy_szfast/classy_szfast.py,sha256=x-svNHmiyhsfH3p0MAoCe86fs0EghqQpLkqIJLyWGp8,45776
|
4
4
|
classy_szfast/config.py,sha256=v6DGcBHmfn5JtuO48dKyXCh-Dmn0uwOF_izvVOJFnqw,279
|
5
5
|
classy_szfast/cosmopower.py,sha256=ooYK2BDOZSo3XtGHfPtjXHxr5UW-yVngLPkb5gpvTx8,2351
|
6
|
-
classy_szfast/cosmopower_jax.py,sha256=
|
6
|
+
classy_szfast/cosmopower_jax.py,sha256=CWrfXG0W4FQ2FbzRz7X7ZZoh6FbDT56XU0g2QFunna4,6680
|
7
7
|
classy_szfast/cosmosis_classy_szfast_interface.py,sha256=zAnxvFtn73a5yS7jgs59zpWFEYKCIQyraYPs5hQ4Le8,11483
|
8
8
|
classy_szfast/emulators_meta_data.py,sha256=mXG5LQuJw9QBNE_kxXW8Kx0AUCWpbV6uRO9BaBbIfHo,10732
|
9
9
|
classy_szfast/pks_and_sigmas.py,sha256=drtuujE1HhlrYY1hY92DyY5lXlYS1uE15MSuVI4uo6k,6625
|
@@ -14,7 +14,7 @@ classy_szfast/custom_bias/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJW
|
|
14
14
|
classy_szfast/custom_bias/custom_bias.py,sha256=aR2t5RTIwv7P0m2bsEU0Eq6BTkj4pG10AebH6QpG4qM,486
|
15
15
|
classy_szfast/custom_profiles/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
16
16
|
classy_szfast/custom_profiles/custom_profiles.py,sha256=4LZwb2XoqwCyWNmW2s24Z7AJdmgVdaRG7yYaBYe-d9Q,1188
|
17
|
-
classy_szfast-0.0.25.
|
18
|
-
classy_szfast-0.0.25.
|
19
|
-
classy_szfast-0.0.25.
|
20
|
-
classy_szfast-0.0.25.
|
17
|
+
classy_szfast-0.0.25.post27.dist-info/METADATA,sha256=FUM3Nf3QS2nf9c4CM4FW9vwCAXQl1OBJ6j8K6XrGwIg,579
|
18
|
+
classy_szfast-0.0.25.post27.dist-info/WHEEL,sha256=pxyMxgL8-pra_rKaQ4drOZAegBVuX-G_4nRHjjgWbmo,91
|
19
|
+
classy_szfast-0.0.25.post27.dist-info/top_level.txt,sha256=hRgqpilUck4lx2KkaWI2y9aCDKqF6pFfGHfNaoPFxv0,14
|
20
|
+
classy_szfast-0.0.25.post27.dist-info/RECORD,,
|
{classy_szfast-0.0.25.post26.dist-info → classy_szfast-0.0.25.post27.dist-info}/top_level.txt
RENAMED
File without changes
|