classy-szfast 0.0.25.post2__py3-none-any.whl → 0.0.25.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -92,12 +92,14 @@ class Class_szfast(object):
92
92
  self.cp_pknl_nn = cp_pknl_nn
93
93
  self.cp_pkl_nn = cp_pkl_nn
94
94
  self.cp_der_nn = cp_der_nn
95
- self.cp_da_nn = cp_da_nn
95
+
96
96
 
97
97
  if self.jax_mode:
98
98
  self.cp_h_nn = cp_h_nn_jax
99
+ self.cp_da_nn = cp_da_nn_jax
99
100
  else:
100
101
  self.cp_h_nn = cp_h_nn
102
+ self.cp_da_nn = cp_da_nn
101
103
 
102
104
  self.cp_s8_nn = cp_s8_nn
103
105
 
@@ -785,35 +787,50 @@ class Class_szfast(object):
785
787
  update_params_with_defaults(params_values, self.emulator_dict[self.cosmo_model]['default'])
786
788
 
787
789
  params_dict = {}
788
-
789
790
  for k,v in zip(params_values.keys(),params_values.values()):
790
-
791
791
  params_dict[k]=[v]
792
792
 
793
793
  if 'm_ncdm' in params_dict.keys():
794
794
  if isinstance(params_dict['m_ncdm'][0],str):
795
795
  params_dict['m_ncdm'] = [float(params_dict['m_ncdm'][0].split(',')[0])]
796
796
 
797
- # deal with different scaling of DA in different model from emulator training
798
- if self.cosmo_model == 'ede-v2':
797
+ if self.jax_mode:
798
+ # print("JAX MODE in chi")
799
+ # self.cp_da_nn[self.cosmo_model].log = False
800
+ self.cp_predicted_da = self.cp_da_nn[self.cosmo_model].predict(params_dict)
801
+ # print("self.cp_predicted_da",self.cp_predicted_da)
802
+
803
+ if self.cosmo_model == 'ede-v2':
804
+ # print('ede-v2 case')
805
+ self.cp_predicted_da = jnp.insert(self.cp_predicted_da, 0, 0)
806
+ self.cp_predicted_da *= (1.+self.cp_z_interp_jax)
807
+
808
+ def chi_interp(x):
809
+ return jnp.interp(x, self.cp_z_interp_jax, self.cp_predicted_da, left=jnp.nan, right=jnp.nan)
810
+
811
+ self.chi_interp = chi_interp
799
812
 
800
- self.cp_predicted_da = self.cp_da_nn[self.cosmo_model].ten_to_predictions_np(params_dict)[0]
801
- self.cp_predicted_da = np.insert(self.cp_predicted_da, 0, 0)
802
-
803
813
  else:
804
-
805
- self.cp_predicted_da = self.cp_da_nn[self.cosmo_model].predictions_np(params_dict)[0]
806
-
814
+ # deal with different scaling of DA in different model from emulator training
815
+ if self.cosmo_model == 'ede-v2':
807
816
 
808
- self.chi_interp = scipy.interpolate.interp1d(
809
- self.cp_z_interp,
810
- self.cp_predicted_da*(1.+self.cp_z_interp),
811
- kind='linear',
812
- axis=-1,
813
- copy=True,
814
- bounds_error=None,
815
- fill_value=np.nan,
816
- assume_sorted=False)
817
+ self.cp_predicted_da = self.cp_da_nn[self.cosmo_model].ten_to_predictions_np(params_dict)[0]
818
+ self.cp_predicted_da = np.insert(self.cp_predicted_da, 0, 0)
819
+
820
+ else:
821
+
822
+ self.cp_predicted_da = self.cp_da_nn[self.cosmo_model].predictions_np(params_dict)[0]
823
+
824
+
825
+ self.chi_interp = scipy.interpolate.interp1d(
826
+ self.cp_z_interp,
827
+ self.cp_predicted_da*(1.+self.cp_z_interp),
828
+ kind='linear',
829
+ axis=-1,
830
+ copy=True,
831
+ bounds_error=None,
832
+ fill_value=np.nan,
833
+ assume_sorted=False)
817
834
 
818
835
  def get_cmb_cls(self,ell_factor=True,Tcmb_uk = Tcmb_uk):
819
836
 
@@ -1,5 +1,6 @@
1
1
  from .config import path_to_class_sz_data
2
2
  import numpy as np
3
+ import jax.numpy as jnp
3
4
  from .restore_nn import Restore_NN
4
5
  from .restore_nn import Restore_PCAplusNN
5
6
  from .suppress_warnings import suppress_warnings
@@ -20,6 +21,74 @@ cp_h_nn_jax = {}
20
21
  cp_s8_nn_jax = {}
21
22
 
22
23
 
24
class CosmoPowerJAX_custom(CPJ):
    """CosmoPowerJAX emulator with an optional final ``10**`` step.

    Subclass of ``CPJ`` (imported elsewhere in this module) that overrides the
    network forward pass. For emulators with ``self.log`` set, the base
    behaviour is to return ``10**preds``; setting ``ten_to_predictions`` to
    False returns the de-standardised network output without exponentiation.

    Parameters
    ----------
    *args, **kwargs
        Forwarded unchanged to ``CPJ.__init__``.
    ten_to_predictions : bool, optional
        Keyword-only. Whether log-emulator outputs are raised to the power of
        ten in ``_predict``. Defaults to True. It is consumed here instead of
        being forwarded, so the base constructor never receives it.
    """

    def __init__(self, *args, ten_to_predictions=True, **kwargs):
        super().__init__(*args, **kwargs)
        # Can also be toggled after construction by assigning the attribute.
        self.ten_to_predictions = ten_to_predictions

    def _predict(self, weights, hyper_params, param_train_mean, param_train_std,
                 feature_train_mean, feature_train_std, input_vec):
        """ Forward pass through pre-trained network.
        In its current form, it does not make use of high-level frameworks like
        FLAX et similia; rather, it simply loops over the network layers.

        Parameters
        ----------
        weights : array
            The stored weights of the neural network; each entry unpacks
            into a ``(w, b)`` pair.
        hyper_params : array
            The stored hyperparameters of the activation function for each
            hidden layer; each entry unpacks into an ``(alpha, beta)`` pair.
        param_train_mean : array
            The stored mean of the training cosmological parameters.
        param_train_std : array
            The stored standard deviation of the training cosmological parameters.
        feature_train_mean : array
            The stored mean of the training features.
        feature_train_std : array
            The stored standard deviation of the training features.
        input_vec : array of shape (n_samples, n_parameters) or (n_parameters)
            The cosmological parameters given as input to the network.

        Returns
        -------
        predictions : array
            The prediction of the trained neural network, with size-1
            dimensions squeezed out.
        """
        # Standardise the inputs with the training statistics.
        hidden = (input_vec - param_train_mean) / param_train_std

        # Hidden layers: affine map followed by the emulator's activation.
        # Only the current layer output is kept; earlier layers are not needed.
        for i, (w, b) in enumerate(weights[:-1]):
            alpha, beta = hyper_params[i]
            hidden = self._activation(jnp.dot(hidden, w.T) + b, alpha, beta)

        # Final layer prediction (no activations)
        w, b = weights[-1]
        if self.probe in ('custom_log', 'custom_pca'):
            # in original CP models, we assumed a full final bias vector...
            preds = jnp.dot(hidden, w.T) + b
        else:
            # ... unlike in cpjax, where we used only a single bias vector
            preds = jnp.dot(hidden, w.T) + b[-1]

        # Undo the standardisation
        preds = preds * feature_train_std + feature_train_mean
        if self.log:
            # Log-trained emulators: exponentiate unless disabled for this instance.
            if self.ten_to_predictions:
                preds = 10**preds
        else:
            # PCA emulators: project back to feature space and de-standardise.
            preds = (preds @ self.pca_matrix) * self.training_std + self.training_mean
            if self.probe == 'cmb_pp':
                preds = 10**preds
        return preds.squeeze()
91
+
23
92
  for mp in cosmo_model_list:
24
93
  folder, version = split_emulator_string(mp)
25
94
  # print(folder, version)
@@ -40,14 +109,19 @@ for mp in cosmo_model_list:
40
109
 
41
110
  cp_der_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'derived-parameters/' + emulator_dict[mp]['DER'])
42
111
 
43
- cp_da_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['DAZ'])
112
+ # cp_da_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['DAZ'])
113
+
44
114
 
115
+ cp_da_nn_jax[mp] = CosmoPowerJAX_custom(probe='custom_log',filepath=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['DAZ'] + '.npz')
116
+ if mp != 'ede-v2':
117
+ cp_da_nn_jax[mp].ten_to_predictions = False
118
+ # print(cp_da_nn_jax[mp].parameters)
45
119
  # print(path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['HZ'])
46
- emulator_custom = CPJ(probe='custom_log',filepath=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['HZ'] + '.npz')
120
+ # emulator_custom = CPJ(probe='custom_log',filepath=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['HZ'] + '.npz')
47
121
  # print(emulator_custom.parameters)
48
122
  # exit()
49
123
 
50
- cp_h_nn_jax[mp] = CPJ(probe='custom_log',filepath=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['HZ'] + '.npz')
124
+ cp_h_nn_jax[mp] = CosmoPowerJAX_custom(probe='custom_log',filepath=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['HZ'] + '.npz')
51
125
 
52
126
  cp_s8_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['S8Z'])
53
127
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: classy_szfast
3
- Version: 0.0.25.post2
3
+ Version: 0.0.25.post3
4
4
  Summary: The accelerator of the class_sz code from https://github.com/CLASS-SZ
5
5
  Maintainer-email: Boris Bolliet <bb667@cam.ac.uk>
6
6
  License: MIT
@@ -1,9 +1,9 @@
1
1
  classy_szfast/__init__.py,sha256=E2thrL0Z9oXFfdzwcsu-xbOytudLFTlRlPqVFGlPPPg,279
2
2
  classy_szfast/classy_sz.py,sha256=QmbwrSXInQLMvCDqsr7KPmtaU0KOiOt1Rb-cTKuulZw,22240
3
- classy_szfast/classy_szfast.py,sha256=2eRVEjpTY9xkbHtIUZjZjKQJEcT_0-5ku0n63VYlXLo,36639
3
+ classy_szfast/classy_szfast.py,sha256=QGXqyzMFdF8mT0RxpJxwBswvM-WRmYFgRkfzUawTmH0,37459
4
4
  classy_szfast/config.py,sha256=cd7Z62-qnX_4FJWfUNqcyJVh-AdBiXrF8DcQGpyAUZM,274
5
5
  classy_szfast/cosmopower.py,sha256=ooYK2BDOZSo3XtGHfPtjXHxr5UW-yVngLPkb5gpvTx8,2351
6
- classy_szfast/cosmopower_jax.py,sha256=C7NzfMFs9sL8rKuDdXdmwxk0UzHqNJnVjZENak-EPQA,2151
6
+ classy_szfast/cosmopower_jax.py,sha256=Bmyzl15OH4KwmpcXfWm1s4XlvlKKkDfgMy70zhnTV5I,5362
7
7
  classy_szfast/cosmosis_classy_szfast_interface.py,sha256=zAnxvFtn73a5yS7jgs59zpWFEYKCIQyraYPs5hQ4Le8,11483
8
8
  classy_szfast/emulators_meta_data.py,sha256=-lHneGhSJ2481S48viz_bNeCyAGu1Ogee0jFEB8B618,9724
9
9
  classy_szfast/pks_and_sigmas.py,sha256=drtuujE1HhlrYY1hY92DyY5lXlYS1uE15MSuVI4uo6k,6625
@@ -14,7 +14,7 @@ classy_szfast/custom_bias/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJW
14
14
  classy_szfast/custom_bias/custom_bias.py,sha256=aR2t5RTIwv7P0m2bsEU0Eq6BTkj4pG10AebH6QpG4qM,486
15
15
  classy_szfast/custom_profiles/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
16
16
  classy_szfast/custom_profiles/custom_profiles.py,sha256=4LZwb2XoqwCyWNmW2s24Z7AJdmgVdaRG7yYaBYe-d9Q,1188
17
- classy_szfast-0.0.25.post2.dist-info/METADATA,sha256=SHRC4v8N4uo_BSIHn9D8_KrReCq3coWsZCQIpLqaKmQ,548
18
- classy_szfast-0.0.25.post2.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
19
- classy_szfast-0.0.25.post2.dist-info/top_level.txt,sha256=hRgqpilUck4lx2KkaWI2y9aCDKqF6pFfGHfNaoPFxv0,14
20
- classy_szfast-0.0.25.post2.dist-info/RECORD,,
17
+ classy_szfast-0.0.25.post3.dist-info/METADATA,sha256=JEdn1Go9uH7wuk0OWfoLzX9fRSKNsVI_fH9SF6AM9vQ,548
18
+ classy_szfast-0.0.25.post3.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
19
+ classy_szfast-0.0.25.post3.dist-info/top_level.txt,sha256=hRgqpilUck4lx2KkaWI2y9aCDKqF6pFfGHfNaoPFxv0,14
20
+ classy_szfast-0.0.25.post3.dist-info/RECORD,,