classy-szfast 0.0.25.post1__py3-none-any.whl → 0.0.25.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- classy_szfast/classy_szfast.py +39 -20
- classy_szfast/cosmopower_jax.py +77 -3
- {classy_szfast-0.0.25.post1.dist-info → classy_szfast-0.0.25.post3.dist-info}/METADATA +1 -1
- {classy_szfast-0.0.25.post1.dist-info → classy_szfast-0.0.25.post3.dist-info}/RECORD +6 -6
- {classy_szfast-0.0.25.post1.dist-info → classy_szfast-0.0.25.post3.dist-info}/WHEEL +0 -0
- {classy_szfast-0.0.25.post1.dist-info → classy_szfast-0.0.25.post3.dist-info}/top_level.txt +0 -0
classy_szfast/classy_szfast.py
CHANGED
@@ -92,12 +92,14 @@ class Class_szfast(object):
         self.cp_pknl_nn = cp_pknl_nn
         self.cp_pkl_nn = cp_pkl_nn
         self.cp_der_nn = cp_der_nn
-
+
 
         if self.jax_mode:
             self.cp_h_nn = cp_h_nn_jax
+            self.cp_da_nn = cp_da_nn_jax
         else:
             self.cp_h_nn = cp_h_nn
+            self.cp_da_nn = cp_da_nn
 
         self.cp_s8_nn = cp_s8_nn
 
@@ -741,7 +743,9 @@ class Class_szfast(object):
         # print("JAX MODE in hubble")
         # self.cp_predicted_hubble = self.cp_h_nn[self.cosmo_model].ten_to_predictions_np(params_dict)[0]
         # print(params_dict)
+        # print("params_dict type:", type(params_dict))
         self.cp_predicted_hubble = self.cp_h_nn[self.cosmo_model].predict(params_dict)
+        # print("self.cp_predicted_hubble type:", type(self.cp_predicted_hubble))
         # print("self.cp_predicted_hubble",self.cp_predicted_hubble)
 
         # self.hz_interp = jscipy.interpolate.interp1d(
@@ -783,35 +787,50 @@ class Class_szfast(object):
         update_params_with_defaults(params_values, self.emulator_dict[self.cosmo_model]['default'])
 
         params_dict = {}
-
         for k,v in zip(params_values.keys(),params_values.values()):
-
             params_dict[k]=[v]
 
         if 'm_ncdm' in params_dict.keys():
             if isinstance(params_dict['m_ncdm'][0],str):
                 params_dict['m_ncdm'] = [float(params_dict['m_ncdm'][0].split(',')[0])]
 
-
-
+        if self.jax_mode:
+            # print("JAX MODE in chi")
+            # self.cp_da_nn[self.cosmo_model].log = False
+            self.cp_predicted_da = self.cp_da_nn[self.cosmo_model].predict(params_dict)
+            # print("self.cp_predicted_da",self.cp_predicted_da)
+
+            if self.cosmo_model == 'ede-v2':
+                # print('ede-v2 case')
+                self.cp_predicted_da = jnp.insert(self.cp_predicted_da, 0, 0)
+            self.cp_predicted_da *= (1.+self.cp_z_interp_jax)
+
+            def chi_interp(x):
+                return jnp.interp(x, self.cp_z_interp_jax, self.cp_predicted_da, left=jnp.nan, right=jnp.nan)
+
+            self.chi_interp = chi_interp
 
-        self.cp_predicted_da = self.cp_da_nn[self.cosmo_model].ten_to_predictions_np(params_dict)[0]
-        self.cp_predicted_da = np.insert(self.cp_predicted_da, 0, 0)
-
         else:
-
-
-
+            # deal with different scaling of DA in different model from emulator training
+            if self.cosmo_model == 'ede-v2':
 
-
-
-
-
-
-
-
-
-
+                self.cp_predicted_da = self.cp_da_nn[self.cosmo_model].ten_to_predictions_np(params_dict)[0]
+                self.cp_predicted_da = np.insert(self.cp_predicted_da, 0, 0)
+
+            else:
+
+                self.cp_predicted_da = self.cp_da_nn[self.cosmo_model].predictions_np(params_dict)[0]
+
+
+            self.chi_interp = scipy.interpolate.interp1d(
+                                self.cp_z_interp,
+                                self.cp_predicted_da*(1.+self.cp_z_interp),
+                                kind='linear',
+                                axis=-1,
+                                copy=True,
+                                bounds_error=None,
+                                fill_value=np.nan,
+                                assume_sorted=False)
 
     def get_cmb_cls(self,ell_factor=True,Tcmb_uk = Tcmb_uk):
 
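In JAX mode the comoving-distance interpolator is now a plain jnp.interp closure over the emulated D_A grid, while the NumPy path keeps the original scipy.interpolate.interp1d call. A minimal, self-contained sketch of that pattern, using made-up grid and prediction arrays rather than real emulator outputs:

import jax.numpy as jnp
import numpy as np
import scipy.interpolate

# Hypothetical stand-ins for the emulator redshift grid and angular-diameter-distance
# predictions; the real arrays come from cp_z_interp_jax / cp_da_nn in the diff above.
z_grid = jnp.linspace(0.0, 20.0, 5000)
d_a = jnp.sqrt(z_grid)                      # placeholder values only

# chi(z) = (1 + z) * D_A(z); NaN outside the tabulated range, as in the JAX branch.
chi_vals = d_a * (1.0 + z_grid)

def chi_interp(z):
    return jnp.interp(z, z_grid, chi_vals, left=jnp.nan, right=jnp.nan)

# The non-JAX branch keeps the original scipy.interpolate.interp1d behaviour.
chi_interp_np = scipy.interpolate.interp1d(
    np.asarray(z_grid), np.asarray(chi_vals),
    kind='linear', bounds_error=None, fill_value=np.nan)

print(chi_interp(1.0), chi_interp_np(1.0))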
classy_szfast/cosmopower_jax.py
CHANGED
@@ -1,5 +1,6 @@
 from .config import path_to_class_sz_data
 import numpy as np
+import jax.numpy as jnp
 from .restore_nn import Restore_NN
 from .restore_nn import Restore_PCAplusNN
 from .suppress_warnings import suppress_warnings
@@ -20,6 +21,74 @@ cp_h_nn_jax = {}
 cp_s8_nn_jax = {}
 
 
+class CosmoPowerJAX_custom(CPJ):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.ten_to_predictions = True
+        if 'ten_to_predictions' in kwargs.keys():
+            self.ten_to_predictions = kwargs['ten_to_predictions']
+
+    def _predict(self, weights, hyper_params, param_train_mean, param_train_std,
+                 feature_train_mean, feature_train_std, input_vec):
+        """ Forward pass through pre-trained network.
+        In its current form, it does not make use of high-level frameworks like
+        FLAX et similia; rather, it simply loops over the network layers.
+        In future work this can be improved, especially if speed is a problem.
+
+        Parameters
+        ----------
+        weights : array
+            The stored weights of the neural network.
+        hyper_params : array
+            The stored hyperparameters of the activation function for each layer.
+        param_train_mean : array
+            The stored mean of the training cosmological parameters.
+        param_train_std : array
+            The stored standard deviation of the training cosmological parameters.
+        feature_train_mean : array
+            The stored mean of the training features.
+        feature_train_std : array
+            The stored standard deviation of the training features.
+        input_vec : array of shape (n_samples, n_parameters) or (n_parameters)
+            The cosmological parameters given as input to the network.
+
+        Returns
+        -------
+        predictions : array
+            The prediction of the trained neural network.
+        """
+        act = []
+        # Standardise
+        layer_out = [(input_vec - param_train_mean)/param_train_std]
+
+        # Loop over layers
+        for i in range(len(weights[:-1])):
+            w, b = weights[i]
+            alpha, beta = hyper_params[i]
+            act.append(jnp.dot(layer_out[-1], w.T) + b)
+            layer_out.append(self._activation(act[-1], alpha, beta))
+
+        # Final layer prediction (no activations)
+        w, b = weights[-1]
+        if self.probe == 'custom_log' or self.probe == 'custom_pca':
+            # in original CP models, we assumed a full final bias vector...
+            preds = jnp.dot(layer_out[-1], w.T) + b
+        else:
+            # ... unlike in cpjax, where we used only a single bias vector
+            preds = jnp.dot(layer_out[-1], w.T) + b[-1]
+
+        # Undo the standardisation
+        preds = preds * feature_train_std + feature_train_mean
+        if self.log == True:
+            if self.ten_to_predictions:
+                preds = 10**preds
+        else:
+            preds = (preds@self.pca_matrix)*self.training_std + self.training_mean
+            if self.probe == 'cmb_pp':
+                preds = 10**preds
+        predictions = preds.squeeze()
+        return predictions
+
 for mp in cosmo_model_list:
     folder, version = split_emulator_string(mp)
     # print(folder, version)
@@ -40,14 +109,19 @@ for mp in cosmo_model_list:
 
     cp_der_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'derived-parameters/' + emulator_dict[mp]['DER'])
 
-    cp_da_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['DAZ'])
+    # cp_da_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['DAZ'])
+
 
+    cp_da_nn_jax[mp] = CosmoPowerJAX_custom(probe='custom_log',filepath=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['DAZ'] + '.npz')
+    if mp != 'ede-v2':
+        cp_da_nn_jax[mp].ten_to_predictions = False
+    # print(cp_da_nn_jax[mp].parameters)
     # print(path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['HZ'])
-    emulator_custom = CPJ(probe='custom_log',filepath=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['HZ'] + '.npz')
+    # emulator_custom = CPJ(probe='custom_log',filepath=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['HZ'] + '.npz')
     # print(emulator_custom.parameters)
     # exit()
 
-    cp_h_nn_jax[mp] =
+    cp_h_nn_jax[mp] = CosmoPowerJAX_custom(probe='custom_log',filepath=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['HZ'] + '.npz')
 
     cp_s8_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['S8Z'])
 
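The new CosmoPowerJAX_custom._predict keeps the standard CosmoPower forward pass (standardise the inputs, apply the dense layers with the trainable sigmoid-gated activation, un-standardise the outputs) and only makes the final 10**preds step conditional on ten_to_predictions. A hedged, self-contained sketch of that pass on a toy two-layer network (the activation formula is the published CosmoPower one; the weights below are dummies, not emulator weights):

import jax.numpy as jnp

def activation(x, alpha, beta):
    # CosmoPower activation: (beta + (1 - beta) * sigmoid(alpha * x)) * x
    return (beta + (1.0 - beta) / (1.0 + jnp.exp(-alpha * x))) * x

def forward(weights, hyper_params, x_mean, x_std, y_mean, y_std, x, ten_to_predictions=True):
    out = (x - x_mean) / x_std                                  # standardise inputs
    for (w, b), (alpha, beta) in zip(weights[:-1], hyper_params):
        out = activation(jnp.dot(out, w.T) + b, alpha, beta)    # hidden layers
    w, b = weights[-1]
    preds = jnp.dot(out, w.T) + b                               # full final bias, as in the 'custom_log' branch
    preds = preds * y_std + y_mean                              # undo standardisation
    return 10**preds if ten_to_predictions else preds           # optional de-log step

# Toy 2 -> 3 -> 1 network with dummy weights and activation hyperparameters.
weights = [(jnp.ones((3, 2)), jnp.zeros(3)), (jnp.ones((1, 3)), jnp.zeros(1))]
hyper_params = [(jnp.ones(3), jnp.ones(3))]
x = jnp.array([0.5, 1.0])
print(forward(weights, hyper_params, 0.0, 1.0, 0.0, 1.0, x, ten_to_predictions=False))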
{classy_szfast-0.0.25.post1.dist-info → classy_szfast-0.0.25.post3.dist-info}/RECORD
CHANGED
@@ -1,9 +1,9 @@
 classy_szfast/__init__.py,sha256=E2thrL0Z9oXFfdzwcsu-xbOytudLFTlRlPqVFGlPPPg,279
 classy_szfast/classy_sz.py,sha256=QmbwrSXInQLMvCDqsr7KPmtaU0KOiOt1Rb-cTKuulZw,22240
-classy_szfast/classy_szfast.py,sha256=
+classy_szfast/classy_szfast.py,sha256=QGXqyzMFdF8mT0RxpJxwBswvM-WRmYFgRkfzUawTmH0,37459
 classy_szfast/config.py,sha256=cd7Z62-qnX_4FJWfUNqcyJVh-AdBiXrF8DcQGpyAUZM,274
 classy_szfast/cosmopower.py,sha256=ooYK2BDOZSo3XtGHfPtjXHxr5UW-yVngLPkb5gpvTx8,2351
-classy_szfast/cosmopower_jax.py,sha256=
+classy_szfast/cosmopower_jax.py,sha256=Bmyzl15OH4KwmpcXfWm1s4XlvlKKkDfgMy70zhnTV5I,5362
 classy_szfast/cosmosis_classy_szfast_interface.py,sha256=zAnxvFtn73a5yS7jgs59zpWFEYKCIQyraYPs5hQ4Le8,11483
 classy_szfast/emulators_meta_data.py,sha256=-lHneGhSJ2481S48viz_bNeCyAGu1Ogee0jFEB8B618,9724
 classy_szfast/pks_and_sigmas.py,sha256=drtuujE1HhlrYY1hY92DyY5lXlYS1uE15MSuVI4uo6k,6625
@@ -14,7 +14,7 @@ classy_szfast/custom_bias/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJW
 classy_szfast/custom_bias/custom_bias.py,sha256=aR2t5RTIwv7P0m2bsEU0Eq6BTkj4pG10AebH6QpG4qM,486
 classy_szfast/custom_profiles/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 classy_szfast/custom_profiles/custom_profiles.py,sha256=4LZwb2XoqwCyWNmW2s24Z7AJdmgVdaRG7yYaBYe-d9Q,1188
-classy_szfast-0.0.25.
-classy_szfast-0.0.25.
-classy_szfast-0.0.25.
-classy_szfast-0.0.25.
+classy_szfast-0.0.25.post3.dist-info/METADATA,sha256=JEdn1Go9uH7wuk0OWfoLzX9fRSKNsVI_fH9SF6AM9vQ,548
+classy_szfast-0.0.25.post3.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+classy_szfast-0.0.25.post3.dist-info/top_level.txt,sha256=hRgqpilUck4lx2KkaWI2y9aCDKqF6pFfGHfNaoPFxv0,14
+classy_szfast-0.0.25.post3.dist-info/RECORD,,
{classy_szfast-0.0.25.post1.dist-info → classy_szfast-0.0.25.post3.dist-info}/WHEEL
File without changes
{classy_szfast-0.0.25.post1.dist-info → classy_szfast-0.0.25.post3.dist-info}/top_level.txt
File without changes