classy-szfast 0.0.25.post2__py3-none-any.whl → 0.0.25.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- classy_szfast/classy_szfast.py +37 -20
- classy_szfast/cosmopower_jax.py +77 -3
- {classy_szfast-0.0.25.post2.dist-info → classy_szfast-0.0.25.post3.dist-info}/METADATA +1 -1
- {classy_szfast-0.0.25.post2.dist-info → classy_szfast-0.0.25.post3.dist-info}/RECORD +6 -6
- {classy_szfast-0.0.25.post2.dist-info → classy_szfast-0.0.25.post3.dist-info}/WHEEL +0 -0
- {classy_szfast-0.0.25.post2.dist-info → classy_szfast-0.0.25.post3.dist-info}/top_level.txt +0 -0
classy_szfast/classy_szfast.py
CHANGED
@@ -92,12 +92,14 @@ class Class_szfast(object):
|
|
92
92
|
self.cp_pknl_nn = cp_pknl_nn
|
93
93
|
self.cp_pkl_nn = cp_pkl_nn
|
94
94
|
self.cp_der_nn = cp_der_nn
|
95
|
-
|
95
|
+
|
96
96
|
|
97
97
|
if self.jax_mode:
|
98
98
|
self.cp_h_nn = cp_h_nn_jax
|
99
|
+
self.cp_da_nn = cp_da_nn_jax
|
99
100
|
else:
|
100
101
|
self.cp_h_nn = cp_h_nn
|
102
|
+
self.cp_da_nn = cp_da_nn
|
101
103
|
|
102
104
|
self.cp_s8_nn = cp_s8_nn
|
103
105
|
|
@@ -785,35 +787,50 @@ class Class_szfast(object):
|
|
785
787
|
update_params_with_defaults(params_values, self.emulator_dict[self.cosmo_model]['default'])
|
786
788
|
|
787
789
|
params_dict = {}
|
788
|
-
|
789
790
|
for k,v in zip(params_values.keys(),params_values.values()):
|
790
|
-
|
791
791
|
params_dict[k]=[v]
|
792
792
|
|
793
793
|
if 'm_ncdm' in params_dict.keys():
|
794
794
|
if isinstance(params_dict['m_ncdm'][0],str):
|
795
795
|
params_dict['m_ncdm'] = [float(params_dict['m_ncdm'][0].split(',')[0])]
|
796
796
|
|
797
|
-
|
798
|
-
|
797
|
+
if self.jax_mode:
|
798
|
+
# print("JAX MODE in chi")
|
799
|
+
# self.cp_da_nn[self.cosmo_model].log = False
|
800
|
+
self.cp_predicted_da = self.cp_da_nn[self.cosmo_model].predict(params_dict)
|
801
|
+
# print("self.cp_predicted_da",self.cp_predicted_da)
|
802
|
+
|
803
|
+
if self.cosmo_model == 'ede-v2':
|
804
|
+
# print('ede-v2 case')
|
805
|
+
self.cp_predicted_da = jnp.insert(self.cp_predicted_da, 0, 0)
|
806
|
+
self.cp_predicted_da *= (1.+self.cp_z_interp_jax)
|
807
|
+
|
808
|
+
def chi_interp(x):
|
809
|
+
return jnp.interp(x, self.cp_z_interp_jax, self.cp_predicted_da, left=jnp.nan, right=jnp.nan)
|
810
|
+
|
811
|
+
self.chi_interp = chi_interp
|
799
812
|
|
800
|
-
self.cp_predicted_da = self.cp_da_nn[self.cosmo_model].ten_to_predictions_np(params_dict)[0]
|
801
|
-
self.cp_predicted_da = np.insert(self.cp_predicted_da, 0, 0)
|
802
|
-
|
803
813
|
else:
|
804
|
-
|
805
|
-
|
806
|
-
|
814
|
+
# deal with different scaling of DA in different model from emulator training
|
815
|
+
if self.cosmo_model == 'ede-v2':
|
807
816
|
|
808
|
-
|
809
|
-
|
810
|
-
|
811
|
-
|
812
|
-
|
813
|
-
|
814
|
-
|
815
|
-
|
816
|
-
|
817
|
+
self.cp_predicted_da = self.cp_da_nn[self.cosmo_model].ten_to_predictions_np(params_dict)[0]
|
818
|
+
self.cp_predicted_da = np.insert(self.cp_predicted_da, 0, 0)
|
819
|
+
|
820
|
+
else:
|
821
|
+
|
822
|
+
self.cp_predicted_da = self.cp_da_nn[self.cosmo_model].predictions_np(params_dict)[0]
|
823
|
+
|
824
|
+
|
825
|
+
self.chi_interp = scipy.interpolate.interp1d(
|
826
|
+
self.cp_z_interp,
|
827
|
+
self.cp_predicted_da*(1.+self.cp_z_interp),
|
828
|
+
kind='linear',
|
829
|
+
axis=-1,
|
830
|
+
copy=True,
|
831
|
+
bounds_error=None,
|
832
|
+
fill_value=np.nan,
|
833
|
+
assume_sorted=False)
|
817
834
|
|
818
835
|
def get_cmb_cls(self,ell_factor=True,Tcmb_uk = Tcmb_uk):
|
819
836
|
|
classy_szfast/cosmopower_jax.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1
1
|
from .config import path_to_class_sz_data
|
2
2
|
import numpy as np
|
3
|
+
import jax.numpy as jnp
|
3
4
|
from .restore_nn import Restore_NN
|
4
5
|
from .restore_nn import Restore_PCAplusNN
|
5
6
|
from .suppress_warnings import suppress_warnings
|
@@ -20,6 +21,74 @@ cp_h_nn_jax = {}
|
|
20
21
|
cp_s8_nn_jax = {}
|
21
22
|
|
22
23
|
|
24
|
+
class CosmoPowerJAX_custom(CPJ):
|
25
|
+
def __init__(self, *args, **kwargs):
|
26
|
+
super().__init__(*args, **kwargs)
|
27
|
+
self.ten_to_predictions = True
|
28
|
+
if 'ten_to_predictions' in kwargs.keys():
|
29
|
+
self.ten_to_predictions = kwargs['ten_to_predictions']
|
30
|
+
|
31
|
+
def _predict(self, weights, hyper_params, param_train_mean, param_train_std,
|
32
|
+
feature_train_mean, feature_train_std, input_vec):
|
33
|
+
""" Forward pass through pre-trained network.
|
34
|
+
In its current form, it does not make use of high-level frameworks like
|
35
|
+
FLAX et similia; rather, it simply loops over the network layers.
|
36
|
+
In future work this can be improved, especially if speed is a problem.
|
37
|
+
|
38
|
+
Parameters
|
39
|
+
----------
|
40
|
+
weights : array
|
41
|
+
The stored weights of the neural network.
|
42
|
+
hyper_params : array
|
43
|
+
The stored hyperparameters of the activation function for each layer.
|
44
|
+
param_train_mean : array
|
45
|
+
The stored mean of the training cosmological parameters.
|
46
|
+
param_train_std : array
|
47
|
+
The stored standard deviation of the training cosmological parameters.
|
48
|
+
feature_train_mean : array
|
49
|
+
The stored mean of the training features.
|
50
|
+
feature_train_std : array
|
51
|
+
The stored standard deviation of the training features.
|
52
|
+
input_vec : array of shape (n_samples, n_parameters) or (n_parameters)
|
53
|
+
The cosmological parameters given as input to the network.
|
54
|
+
|
55
|
+
Returns
|
56
|
+
-------
|
57
|
+
predictions : array
|
58
|
+
The prediction of the trained neural network.
|
59
|
+
"""
|
60
|
+
act = []
|
61
|
+
# Standardise
|
62
|
+
layer_out = [(input_vec - param_train_mean)/param_train_std]
|
63
|
+
|
64
|
+
# Loop over layers
|
65
|
+
for i in range(len(weights[:-1])):
|
66
|
+
w, b = weights[i]
|
67
|
+
alpha, beta = hyper_params[i]
|
68
|
+
act.append(jnp.dot(layer_out[-1], w.T) + b)
|
69
|
+
layer_out.append(self._activation(act[-1], alpha, beta))
|
70
|
+
|
71
|
+
# Final layer prediction (no activations)
|
72
|
+
w, b = weights[-1]
|
73
|
+
if self.probe == 'custom_log' or self.probe == 'custom_pca':
|
74
|
+
# in original CP models, we assumed a full final bias vector...
|
75
|
+
preds = jnp.dot(layer_out[-1], w.T) + b
|
76
|
+
else:
|
77
|
+
# ... unlike in cpjax, where we used only a single bias vector
|
78
|
+
preds = jnp.dot(layer_out[-1], w.T) + b[-1]
|
79
|
+
|
80
|
+
# Undo the standardisation
|
81
|
+
preds = preds * feature_train_std + feature_train_mean
|
82
|
+
if self.log == True:
|
83
|
+
if self.ten_to_predictions:
|
84
|
+
preds = 10**preds
|
85
|
+
else:
|
86
|
+
preds = (preds@self.pca_matrix)*self.training_std + self.training_mean
|
87
|
+
if self.probe == 'cmb_pp':
|
88
|
+
preds = 10**preds
|
89
|
+
predictions = preds.squeeze()
|
90
|
+
return predictions
|
91
|
+
|
23
92
|
for mp in cosmo_model_list:
|
24
93
|
folder, version = split_emulator_string(mp)
|
25
94
|
# print(folder, version)
|
@@ -40,14 +109,19 @@ for mp in cosmo_model_list:
|
|
40
109
|
|
41
110
|
cp_der_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'derived-parameters/' + emulator_dict[mp]['DER'])
|
42
111
|
|
43
|
-
cp_da_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['DAZ'])
|
112
|
+
# cp_da_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['DAZ'])
|
113
|
+
|
44
114
|
|
115
|
+
cp_da_nn_jax[mp] = CosmoPowerJAX_custom(probe='custom_log',filepath=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['DAZ'] + '.npz')
|
116
|
+
if mp != 'ede-v2':
|
117
|
+
cp_da_nn_jax[mp].ten_to_predictions = False
|
118
|
+
# print(cp_da_nn_jax[mp].parameters)
|
45
119
|
# print(path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['HZ'])
|
46
|
-
emulator_custom = CPJ(probe='custom_log',filepath=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['HZ'] + '.npz')
|
120
|
+
# emulator_custom = CPJ(probe='custom_log',filepath=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['HZ'] + '.npz')
|
47
121
|
# print(emulator_custom.parameters)
|
48
122
|
# exit()
|
49
123
|
|
50
|
-
cp_h_nn_jax[mp] =
|
124
|
+
cp_h_nn_jax[mp] = CosmoPowerJAX_custom(probe='custom_log',filepath=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['HZ'] + '.npz')
|
51
125
|
|
52
126
|
cp_s8_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['S8Z'])
|
53
127
|
|
@@ -1,9 +1,9 @@
|
|
1
1
|
classy_szfast/__init__.py,sha256=E2thrL0Z9oXFfdzwcsu-xbOytudLFTlRlPqVFGlPPPg,279
|
2
2
|
classy_szfast/classy_sz.py,sha256=QmbwrSXInQLMvCDqsr7KPmtaU0KOiOt1Rb-cTKuulZw,22240
|
3
|
-
classy_szfast/classy_szfast.py,sha256=
|
3
|
+
classy_szfast/classy_szfast.py,sha256=QGXqyzMFdF8mT0RxpJxwBswvM-WRmYFgRkfzUawTmH0,37459
|
4
4
|
classy_szfast/config.py,sha256=cd7Z62-qnX_4FJWfUNqcyJVh-AdBiXrF8DcQGpyAUZM,274
|
5
5
|
classy_szfast/cosmopower.py,sha256=ooYK2BDOZSo3XtGHfPtjXHxr5UW-yVngLPkb5gpvTx8,2351
|
6
|
-
classy_szfast/cosmopower_jax.py,sha256=
|
6
|
+
classy_szfast/cosmopower_jax.py,sha256=Bmyzl15OH4KwmpcXfWm1s4XlvlKKkDfgMy70zhnTV5I,5362
|
7
7
|
classy_szfast/cosmosis_classy_szfast_interface.py,sha256=zAnxvFtn73a5yS7jgs59zpWFEYKCIQyraYPs5hQ4Le8,11483
|
8
8
|
classy_szfast/emulators_meta_data.py,sha256=-lHneGhSJ2481S48viz_bNeCyAGu1Ogee0jFEB8B618,9724
|
9
9
|
classy_szfast/pks_and_sigmas.py,sha256=drtuujE1HhlrYY1hY92DyY5lXlYS1uE15MSuVI4uo6k,6625
|
@@ -14,7 +14,7 @@ classy_szfast/custom_bias/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJW
|
|
14
14
|
classy_szfast/custom_bias/custom_bias.py,sha256=aR2t5RTIwv7P0m2bsEU0Eq6BTkj4pG10AebH6QpG4qM,486
|
15
15
|
classy_szfast/custom_profiles/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
16
16
|
classy_szfast/custom_profiles/custom_profiles.py,sha256=4LZwb2XoqwCyWNmW2s24Z7AJdmgVdaRG7yYaBYe-d9Q,1188
|
17
|
-
classy_szfast-0.0.25.
|
18
|
-
classy_szfast-0.0.25.
|
19
|
-
classy_szfast-0.0.25.
|
20
|
-
classy_szfast-0.0.25.
|
17
|
+
classy_szfast-0.0.25.post3.dist-info/METADATA,sha256=JEdn1Go9uH7wuk0OWfoLzX9fRSKNsVI_fH9SF6AM9vQ,548
|
18
|
+
classy_szfast-0.0.25.post3.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
|
19
|
+
classy_szfast-0.0.25.post3.dist-info/top_level.txt,sha256=hRgqpilUck4lx2KkaWI2y9aCDKqF6pFfGHfNaoPFxv0,14
|
20
|
+
classy_szfast-0.0.25.post3.dist-info/RECORD,,
|
File without changes
|
File without changes
|