classy-szfast 0.0.25.post2__tar.gz → 0.0.25.post3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (26) hide show
  1. {classy_szfast-0.0.25.post2 → classy_szfast-0.0.25.post3}/PKG-INFO +1 -1
  2. {classy_szfast-0.0.25.post2 → classy_szfast-0.0.25.post3}/classy_szfast/classy_szfast.py +37 -20
  3. classy_szfast-0.0.25.post3/classy_szfast/cosmopower_jax.py +127 -0
  4. {classy_szfast-0.0.25.post2 → classy_szfast-0.0.25.post3}/classy_szfast.egg-info/PKG-INFO +1 -1
  5. {classy_szfast-0.0.25.post2 → classy_szfast-0.0.25.post3}/pyproject.toml +1 -1
  6. classy_szfast-0.0.25.post2/classy_szfast/cosmopower_jax.py +0 -53
  7. {classy_szfast-0.0.25.post2 → classy_szfast-0.0.25.post3}/README.md +0 -0
  8. {classy_szfast-0.0.25.post2 → classy_szfast-0.0.25.post3}/classy_szfast/__init__.py +0 -0
  9. {classy_szfast-0.0.25.post2 → classy_szfast-0.0.25.post3}/classy_szfast/classy_sz.py +0 -0
  10. {classy_szfast-0.0.25.post2 → classy_szfast-0.0.25.post3}/classy_szfast/config.py +0 -0
  11. {classy_szfast-0.0.25.post2 → classy_szfast-0.0.25.post3}/classy_szfast/cosmopower.py +0 -0
  12. {classy_szfast-0.0.25.post2 → classy_szfast-0.0.25.post3}/classy_szfast/cosmosis_classy_szfast_interface.py +0 -0
  13. {classy_szfast-0.0.25.post2 → classy_szfast-0.0.25.post3}/classy_szfast/custom_bias/__init__.py +0 -0
  14. {classy_szfast-0.0.25.post2 → classy_szfast-0.0.25.post3}/classy_szfast/custom_bias/custom_bias.py +0 -0
  15. {classy_szfast-0.0.25.post2 → classy_szfast-0.0.25.post3}/classy_szfast/custom_profiles/__init__.py +0 -0
  16. {classy_szfast-0.0.25.post2 → classy_szfast-0.0.25.post3}/classy_szfast/custom_profiles/custom_profiles.py +0 -0
  17. {classy_szfast-0.0.25.post2 → classy_szfast-0.0.25.post3}/classy_szfast/emulators_meta_data.py +0 -0
  18. {classy_szfast-0.0.25.post2 → classy_szfast-0.0.25.post3}/classy_szfast/pks_and_sigmas.py +0 -0
  19. {classy_szfast-0.0.25.post2 → classy_szfast-0.0.25.post3}/classy_szfast/restore_nn.py +0 -0
  20. {classy_szfast-0.0.25.post2 → classy_szfast-0.0.25.post3}/classy_szfast/suppress_warnings.py +0 -0
  21. {classy_szfast-0.0.25.post2 → classy_szfast-0.0.25.post3}/classy_szfast/utils.py +0 -0
  22. {classy_szfast-0.0.25.post2 → classy_szfast-0.0.25.post3}/classy_szfast.egg-info/SOURCES.txt +0 -0
  23. {classy_szfast-0.0.25.post2 → classy_szfast-0.0.25.post3}/classy_szfast.egg-info/dependency_links.txt +0 -0
  24. {classy_szfast-0.0.25.post2 → classy_szfast-0.0.25.post3}/classy_szfast.egg-info/requires.txt +0 -0
  25. {classy_szfast-0.0.25.post2 → classy_szfast-0.0.25.post3}/classy_szfast.egg-info/top_level.txt +0 -0
  26. {classy_szfast-0.0.25.post2 → classy_szfast-0.0.25.post3}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: classy_szfast
3
- Version: 0.0.25.post2
3
+ Version: 0.0.25.post3
4
4
  Summary: The accelerator of the class_sz code from https://github.com/CLASS-SZ
5
5
  Maintainer-email: Boris Bolliet <bb667@cam.ac.uk>
6
6
  License: MIT
@@ -92,12 +92,14 @@ class Class_szfast(object):
92
92
  self.cp_pknl_nn = cp_pknl_nn
93
93
  self.cp_pkl_nn = cp_pkl_nn
94
94
  self.cp_der_nn = cp_der_nn
95
- self.cp_da_nn = cp_da_nn
95
+
96
96
 
97
97
  if self.jax_mode:
98
98
  self.cp_h_nn = cp_h_nn_jax
99
+ self.cp_da_nn = cp_da_nn_jax
99
100
  else:
100
101
  self.cp_h_nn = cp_h_nn
102
+ self.cp_da_nn = cp_da_nn
101
103
 
102
104
  self.cp_s8_nn = cp_s8_nn
103
105
 
@@ -785,35 +787,50 @@ class Class_szfast(object):
785
787
  update_params_with_defaults(params_values, self.emulator_dict[self.cosmo_model]['default'])
786
788
 
787
789
  params_dict = {}
788
-
789
790
  for k,v in zip(params_values.keys(),params_values.values()):
790
-
791
791
  params_dict[k]=[v]
792
792
 
793
793
  if 'm_ncdm' in params_dict.keys():
794
794
  if isinstance(params_dict['m_ncdm'][0],str):
795
795
  params_dict['m_ncdm'] = [float(params_dict['m_ncdm'][0].split(',')[0])]
796
796
 
797
- # deal with different scaling of DA in different model from emulator training
798
- if self.cosmo_model == 'ede-v2':
797
+ if self.jax_mode:
798
+ # print("JAX MODE in chi")
799
+ # self.cp_da_nn[self.cosmo_model].log = False
800
+ self.cp_predicted_da = self.cp_da_nn[self.cosmo_model].predict(params_dict)
801
+ # print("self.cp_predicted_da",self.cp_predicted_da)
802
+
803
+ if self.cosmo_model == 'ede-v2':
804
+ # print('ede-v2 case')
805
+ self.cp_predicted_da = jnp.insert(self.cp_predicted_da, 0, 0)
806
+ self.cp_predicted_da *= (1.+self.cp_z_interp_jax)
807
+
808
+ def chi_interp(x):
809
+ return jnp.interp(x, self.cp_z_interp_jax, self.cp_predicted_da, left=jnp.nan, right=jnp.nan)
810
+
811
+ self.chi_interp = chi_interp
799
812
 
800
- self.cp_predicted_da = self.cp_da_nn[self.cosmo_model].ten_to_predictions_np(params_dict)[0]
801
- self.cp_predicted_da = np.insert(self.cp_predicted_da, 0, 0)
802
-
803
813
  else:
804
-
805
- self.cp_predicted_da = self.cp_da_nn[self.cosmo_model].predictions_np(params_dict)[0]
806
-
814
+ # deal with different scaling of DA in different model from emulator training
815
+ if self.cosmo_model == 'ede-v2':
807
816
 
808
- self.chi_interp = scipy.interpolate.interp1d(
809
- self.cp_z_interp,
810
- self.cp_predicted_da*(1.+self.cp_z_interp),
811
- kind='linear',
812
- axis=-1,
813
- copy=True,
814
- bounds_error=None,
815
- fill_value=np.nan,
816
- assume_sorted=False)
817
+ self.cp_predicted_da = self.cp_da_nn[self.cosmo_model].ten_to_predictions_np(params_dict)[0]
818
+ self.cp_predicted_da = np.insert(self.cp_predicted_da, 0, 0)
819
+
820
+ else:
821
+
822
+ self.cp_predicted_da = self.cp_da_nn[self.cosmo_model].predictions_np(params_dict)[0]
823
+
824
+
825
+ self.chi_interp = scipy.interpolate.interp1d(
826
+ self.cp_z_interp,
827
+ self.cp_predicted_da*(1.+self.cp_z_interp),
828
+ kind='linear',
829
+ axis=-1,
830
+ copy=True,
831
+ bounds_error=None,
832
+ fill_value=np.nan,
833
+ assume_sorted=False)
817
834
 
818
835
  def get_cmb_cls(self,ell_factor=True,Tcmb_uk = Tcmb_uk):
819
836
 
@@ -0,0 +1,127 @@
1
+ from .config import path_to_class_sz_data
2
+ import numpy as np
3
+ import jax.numpy as jnp
4
+ from .restore_nn import Restore_NN
5
+ from .restore_nn import Restore_PCAplusNN
6
+ from .suppress_warnings import suppress_warnings
7
+ from .emulators_meta_data import *
8
+
9
+ from cosmopower_jax.cosmopower_jax import CosmoPowerJAX as CPJ
10
+
11
+
12
+ cp_tt_nn_jax = {}
13
+ cp_te_nn_jax = {}
14
+ cp_ee_nn_jax = {}
15
+ cp_pp_nn_jax = {}
16
+ cp_pknl_nn_jax = {}
17
+ cp_pkl_nn_jax = {}
18
+ cp_der_nn_jax = {}
19
+ cp_da_nn_jax = {}
20
+ cp_h_nn_jax = {}
21
+ cp_s8_nn_jax = {}
22
+
23
+
24
+ class CosmoPowerJAX_custom(CPJ):
25
+ def __init__(self, *args, **kwargs):
26
+ super().__init__(*args, **kwargs)
27
+ self.ten_to_predictions = True
28
+ if 'ten_to_predictions' in kwargs.keys():
29
+ self.ten_to_predictions = kwargs['ten_to_predictions']
30
+
31
+ def _predict(self, weights, hyper_params, param_train_mean, param_train_std,
32
+ feature_train_mean, feature_train_std, input_vec):
33
+ """ Forward pass through pre-trained network.
34
+ In its current form, it does not make use of high-level frameworks like
35
+ FLAX et similia; rather, it simply loops over the network layers.
36
+ In future work this can be improved, especially if speed is a problem.
37
+
38
+ Parameters
39
+ ----------
40
+ weights : array
41
+ The stored weights of the neural network.
42
+ hyper_params : array
43
+ The stored hyperparameters of the activation function for each layer.
44
+ param_train_mean : array
45
+ The stored mean of the training cosmological parameters.
46
+ param_train_std : array
47
+ The stored standard deviation of the training cosmological parameters.
48
+ feature_train_mean : array
49
+ The stored mean of the training features.
50
+ feature_train_std : array
51
+ The stored standard deviation of the training features.
52
+ input_vec : array of shape (n_samples, n_parameters) or (n_parameters)
53
+ The cosmological parameters given as input to the network.
54
+
55
+ Returns
56
+ -------
57
+ predictions : array
58
+ The prediction of the trained neural network.
59
+ """
60
+ act = []
61
+ # Standardise
62
+ layer_out = [(input_vec - param_train_mean)/param_train_std]
63
+
64
+ # Loop over layers
65
+ for i in range(len(weights[:-1])):
66
+ w, b = weights[i]
67
+ alpha, beta = hyper_params[i]
68
+ act.append(jnp.dot(layer_out[-1], w.T) + b)
69
+ layer_out.append(self._activation(act[-1], alpha, beta))
70
+
71
+ # Final layer prediction (no activations)
72
+ w, b = weights[-1]
73
+ if self.probe == 'custom_log' or self.probe == 'custom_pca':
74
+ # in original CP models, we assumed a full final bias vector...
75
+ preds = jnp.dot(layer_out[-1], w.T) + b
76
+ else:
77
+ # ... unlike in cpjax, where we used only a single bias vector
78
+ preds = jnp.dot(layer_out[-1], w.T) + b[-1]
79
+
80
+ # Undo the standardisation
81
+ preds = preds * feature_train_std + feature_train_mean
82
+ if self.log == True:
83
+ if self.ten_to_predictions:
84
+ preds = 10**preds
85
+ else:
86
+ preds = (preds@self.pca_matrix)*self.training_std + self.training_mean
87
+ if self.probe == 'cmb_pp':
88
+ preds = 10**preds
89
+ predictions = preds.squeeze()
90
+ return predictions
91
+
92
+ for mp in cosmo_model_list:
93
+ folder, version = split_emulator_string(mp)
94
+ # print(folder, version)
95
+ path_to_emulators = path_to_class_sz_data + '/' + folder +'/'
96
+
97
+ cp_tt_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'TTTEEE/' + emulator_dict[mp]['TT'])
98
+
99
+ cp_te_nn_jax[mp] = Restore_PCAplusNN(restore_filename=path_to_emulators + 'TTTEEE/' + emulator_dict[mp]['TE'])
100
+
101
+ with suppress_warnings():
102
+ cp_ee_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'TTTEEE/' + emulator_dict[mp]['EE'])
103
+
104
+ cp_pp_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'PP/' + emulator_dict[mp]['PP'])
105
+
106
+ cp_pknl_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'PK/' + emulator_dict[mp]['PKNL'])
107
+
108
+ cp_pkl_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'PK/' + emulator_dict[mp]['PKL'])
109
+
110
+ cp_der_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'derived-parameters/' + emulator_dict[mp]['DER'])
111
+
112
+ # cp_da_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['DAZ'])
113
+
114
+
115
+ cp_da_nn_jax[mp] = CosmoPowerJAX_custom(probe='custom_log',filepath=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['DAZ'] + '.npz')
116
+ if mp != 'ede-v2':
117
+ cp_da_nn_jax[mp].ten_to_predictions = False
118
+ # print(cp_da_nn_jax[mp].parameters)
119
+ # print(path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['HZ'])
120
+ # emulator_custom = CPJ(probe='custom_log',filepath=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['HZ'] + '.npz')
121
+ # print(emulator_custom.parameters)
122
+ # exit()
123
+
124
+ cp_h_nn_jax[mp] = CosmoPowerJAX_custom(probe='custom_log',filepath=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['HZ'] + '.npz')
125
+
126
+ cp_s8_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['S8Z'])
127
+
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: classy_szfast
3
- Version: 0.0.25.post2
3
+ Version: 0.0.25.post3
4
4
  Summary: The accelerator of the class_sz code from https://github.com/CLASS-SZ
5
5
  Maintainer-email: Boris Bolliet <bb667@cam.ac.uk>
6
6
  License: MIT
@@ -3,7 +3,7 @@ requires = ["setuptools", "wheel"]
3
3
  build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
- version = "0.0.25.post2"
6
+ version = "0.0.25.post3"
7
7
  license = { text = "MIT" }
8
8
  name = "classy_szfast"
9
9
  maintainers = [{name = "Boris Bolliet",email="bb667@cam.ac.uk"}]
@@ -1,53 +0,0 @@
1
- from .config import path_to_class_sz_data
2
- import numpy as np
3
- from .restore_nn import Restore_NN
4
- from .restore_nn import Restore_PCAplusNN
5
- from .suppress_warnings import suppress_warnings
6
- from .emulators_meta_data import *
7
-
8
- from cosmopower_jax.cosmopower_jax import CosmoPowerJAX as CPJ
9
-
10
-
11
- cp_tt_nn_jax = {}
12
- cp_te_nn_jax = {}
13
- cp_ee_nn_jax = {}
14
- cp_pp_nn_jax = {}
15
- cp_pknl_nn_jax = {}
16
- cp_pkl_nn_jax = {}
17
- cp_der_nn_jax = {}
18
- cp_da_nn_jax = {}
19
- cp_h_nn_jax = {}
20
- cp_s8_nn_jax = {}
21
-
22
-
23
- for mp in cosmo_model_list:
24
- folder, version = split_emulator_string(mp)
25
- # print(folder, version)
26
- path_to_emulators = path_to_class_sz_data + '/' + folder +'/'
27
-
28
- cp_tt_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'TTTEEE/' + emulator_dict[mp]['TT'])
29
-
30
- cp_te_nn_jax[mp] = Restore_PCAplusNN(restore_filename=path_to_emulators + 'TTTEEE/' + emulator_dict[mp]['TE'])
31
-
32
- with suppress_warnings():
33
- cp_ee_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'TTTEEE/' + emulator_dict[mp]['EE'])
34
-
35
- cp_pp_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'PP/' + emulator_dict[mp]['PP'])
36
-
37
- cp_pknl_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'PK/' + emulator_dict[mp]['PKNL'])
38
-
39
- cp_pkl_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'PK/' + emulator_dict[mp]['PKL'])
40
-
41
- cp_der_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'derived-parameters/' + emulator_dict[mp]['DER'])
42
-
43
- cp_da_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['DAZ'])
44
-
45
- # print(path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['HZ'])
46
- emulator_custom = CPJ(probe='custom_log',filepath=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['HZ'] + '.npz')
47
- # print(emulator_custom.parameters)
48
- # exit()
49
-
50
- cp_h_nn_jax[mp] = CPJ(probe='custom_log',filepath=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['HZ'] + '.npz')
51
-
52
- cp_s8_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['S8Z'])
53
-