classy-szfast 0.0.25.post1__py3-none-any.whl → 0.0.25.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -92,12 +92,14 @@ class Class_szfast(object):
92
92
  self.cp_pknl_nn = cp_pknl_nn
93
93
  self.cp_pkl_nn = cp_pkl_nn
94
94
  self.cp_der_nn = cp_der_nn
95
- self.cp_da_nn = cp_da_nn
95
+
96
96
 
97
97
  if self.jax_mode:
98
98
  self.cp_h_nn = cp_h_nn_jax
99
+ self.cp_da_nn = cp_da_nn_jax
99
100
  else:
100
101
  self.cp_h_nn = cp_h_nn
102
+ self.cp_da_nn = cp_da_nn
101
103
 
102
104
  self.cp_s8_nn = cp_s8_nn
103
105
 
@@ -741,7 +743,9 @@ class Class_szfast(object):
741
743
  # print("JAX MODE in hubble")
742
744
  # self.cp_predicted_hubble = self.cp_h_nn[self.cosmo_model].ten_to_predictions_np(params_dict)[0]
743
745
  # print(params_dict)
746
+ # print("params_dict type:", type(params_dict))
744
747
  self.cp_predicted_hubble = self.cp_h_nn[self.cosmo_model].predict(params_dict)
748
+ # print("self.cp_predicted_hubble type:", type(self.cp_predicted_hubble))
745
749
  # print("self.cp_predicted_hubble",self.cp_predicted_hubble)
746
750
 
747
751
  # self.hz_interp = jscipy.interpolate.interp1d(
@@ -783,35 +787,50 @@ class Class_szfast(object):
783
787
  update_params_with_defaults(params_values, self.emulator_dict[self.cosmo_model]['default'])
784
788
 
785
789
  params_dict = {}
786
-
787
790
  for k,v in zip(params_values.keys(),params_values.values()):
788
-
789
791
  params_dict[k]=[v]
790
792
 
791
793
  if 'm_ncdm' in params_dict.keys():
792
794
  if isinstance(params_dict['m_ncdm'][0],str):
793
795
  params_dict['m_ncdm'] = [float(params_dict['m_ncdm'][0].split(',')[0])]
794
796
 
795
- # deal with different scaling of DA in different model from emulator training
796
- if self.cosmo_model == 'ede-v2':
797
+ if self.jax_mode:
798
+ # print("JAX MODE in chi")
799
+ # self.cp_da_nn[self.cosmo_model].log = False
800
+ self.cp_predicted_da = self.cp_da_nn[self.cosmo_model].predict(params_dict)
801
+ # print("self.cp_predicted_da",self.cp_predicted_da)
802
+
803
+ if self.cosmo_model == 'ede-v2':
804
+ # print('ede-v2 case')
805
+ self.cp_predicted_da = jnp.insert(self.cp_predicted_da, 0, 0)
806
+ self.cp_predicted_da *= (1.+self.cp_z_interp_jax)
807
+
808
+ def chi_interp(x):
809
+ return jnp.interp(x, self.cp_z_interp_jax, self.cp_predicted_da, left=jnp.nan, right=jnp.nan)
810
+
811
+ self.chi_interp = chi_interp
797
812
 
798
- self.cp_predicted_da = self.cp_da_nn[self.cosmo_model].ten_to_predictions_np(params_dict)[0]
799
- self.cp_predicted_da = np.insert(self.cp_predicted_da, 0, 0)
800
-
801
813
  else:
802
-
803
- self.cp_predicted_da = self.cp_da_nn[self.cosmo_model].predictions_np(params_dict)[0]
804
-
814
+ # deal with different scaling of DA in different model from emulator training
815
+ if self.cosmo_model == 'ede-v2':
805
816
 
806
- self.chi_interp = scipy.interpolate.interp1d(
807
- self.cp_z_interp,
808
- self.cp_predicted_da*(1.+self.cp_z_interp),
809
- kind='linear',
810
- axis=-1,
811
- copy=True,
812
- bounds_error=None,
813
- fill_value=np.nan,
814
- assume_sorted=False)
817
+ self.cp_predicted_da = self.cp_da_nn[self.cosmo_model].ten_to_predictions_np(params_dict)[0]
818
+ self.cp_predicted_da = np.insert(self.cp_predicted_da, 0, 0)
819
+
820
+ else:
821
+
822
+ self.cp_predicted_da = self.cp_da_nn[self.cosmo_model].predictions_np(params_dict)[0]
823
+
824
+
825
+ self.chi_interp = scipy.interpolate.interp1d(
826
+ self.cp_z_interp,
827
+ self.cp_predicted_da*(1.+self.cp_z_interp),
828
+ kind='linear',
829
+ axis=-1,
830
+ copy=True,
831
+ bounds_error=None,
832
+ fill_value=np.nan,
833
+ assume_sorted=False)
815
834
 
816
835
  def get_cmb_cls(self,ell_factor=True,Tcmb_uk = Tcmb_uk):
817
836
 
@@ -1,5 +1,6 @@
1
1
  from .config import path_to_class_sz_data
2
2
  import numpy as np
3
+ import jax.numpy as jnp
3
4
  from .restore_nn import Restore_NN
4
5
  from .restore_nn import Restore_PCAplusNN
5
6
  from .suppress_warnings import suppress_warnings
@@ -20,6 +21,74 @@ cp_h_nn_jax = {}
20
21
  cp_s8_nn_jax = {}
21
22
 
22
23
 
24
+ class CosmoPowerJAX_custom(CPJ):
25
+ def __init__(self, *args, **kwargs):
26
+ super().__init__(*args, **kwargs)
27
+ self.ten_to_predictions = True
28
+ if 'ten_to_predictions' in kwargs.keys():
29
+ self.ten_to_predictions = kwargs['ten_to_predictions']
30
+
31
+ def _predict(self, weights, hyper_params, param_train_mean, param_train_std,
32
+ feature_train_mean, feature_train_std, input_vec):
33
+ """ Forward pass through pre-trained network.
34
+ In its current form, it does not make use of high-level frameworks like
35
+ FLAX et similia; rather, it simply loops over the network layers.
36
+ In future work this can be improved, especially if speed is a problem.
37
+
38
+ Parameters
39
+ ----------
40
+ weights : array
41
+ The stored weights of the neural network.
42
+ hyper_params : array
43
+ The stored hyperparameters of the activation function for each layer.
44
+ param_train_mean : array
45
+ The stored mean of the training cosmological parameters.
46
+ param_train_std : array
47
+ The stored standard deviation of the training cosmological parameters.
48
+ feature_train_mean : array
49
+ The stored mean of the training features.
50
+ feature_train_std : array
51
+ The stored standard deviation of the training features.
52
+ input_vec : array of shape (n_samples, n_parameters) or (n_parameters)
53
+ The cosmological parameters given as input to the network.
54
+
55
+ Returns
56
+ -------
57
+ predictions : array
58
+ The prediction of the trained neural network.
59
+ """
60
+ act = []
61
+ # Standardise
62
+ layer_out = [(input_vec - param_train_mean)/param_train_std]
63
+
64
+ # Loop over layers
65
+ for i in range(len(weights[:-1])):
66
+ w, b = weights[i]
67
+ alpha, beta = hyper_params[i]
68
+ act.append(jnp.dot(layer_out[-1], w.T) + b)
69
+ layer_out.append(self._activation(act[-1], alpha, beta))
70
+
71
+ # Final layer prediction (no activations)
72
+ w, b = weights[-1]
73
+ if self.probe == 'custom_log' or self.probe == 'custom_pca':
74
+ # in original CP models, we assumed a full final bias vector...
75
+ preds = jnp.dot(layer_out[-1], w.T) + b
76
+ else:
77
+ # ... unlike in cpjax, where we used only a single bias vector
78
+ preds = jnp.dot(layer_out[-1], w.T) + b[-1]
79
+
80
+ # Undo the standardisation
81
+ preds = preds * feature_train_std + feature_train_mean
82
+ if self.log == True:
83
+ if self.ten_to_predictions:
84
+ preds = 10**preds
85
+ else:
86
+ preds = (preds@self.pca_matrix)*self.training_std + self.training_mean
87
+ if self.probe == 'cmb_pp':
88
+ preds = 10**preds
89
+ predictions = preds.squeeze()
90
+ return predictions
91
+
23
92
  for mp in cosmo_model_list:
24
93
  folder, version = split_emulator_string(mp)
25
94
  # print(folder, version)
@@ -40,14 +109,19 @@ for mp in cosmo_model_list:
40
109
 
41
110
  cp_der_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'derived-parameters/' + emulator_dict[mp]['DER'])
42
111
 
43
- cp_da_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['DAZ'])
112
+ # cp_da_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['DAZ'])
113
+
44
114
 
115
+ cp_da_nn_jax[mp] = CosmoPowerJAX_custom(probe='custom_log',filepath=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['DAZ'] + '.npz')
116
+ if mp != 'ede-v2':
117
+ cp_da_nn_jax[mp].ten_to_predictions = False
118
+ # print(cp_da_nn_jax[mp].parameters)
45
119
  # print(path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['HZ'])
46
- emulator_custom = CPJ(probe='custom_log',filepath=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['HZ'] + '.npz')
120
+ # emulator_custom = CPJ(probe='custom_log',filepath=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['HZ'] + '.npz')
47
121
  # print(emulator_custom.parameters)
48
122
  # exit()
49
123
 
50
- cp_h_nn_jax[mp] = CPJ(probe='custom_log',filepath=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['HZ'] + '.npz')
124
+ cp_h_nn_jax[mp] = CosmoPowerJAX_custom(probe='custom_log',filepath=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['HZ'] + '.npz')
51
125
 
52
126
  cp_s8_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['S8Z'])
53
127
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: classy_szfast
3
- Version: 0.0.25.post1
3
+ Version: 0.0.25.post3
4
4
  Summary: The accelerator of the class_sz code from https://github.com/CLASS-SZ
5
5
  Maintainer-email: Boris Bolliet <bb667@cam.ac.uk>
6
6
  License: MIT
@@ -1,9 +1,9 @@
1
1
  classy_szfast/__init__.py,sha256=E2thrL0Z9oXFfdzwcsu-xbOytudLFTlRlPqVFGlPPPg,279
2
2
  classy_szfast/classy_sz.py,sha256=QmbwrSXInQLMvCDqsr7KPmtaU0KOiOt1Rb-cTKuulZw,22240
3
- classy_szfast/classy_szfast.py,sha256=p2N3UYK0Gmy79wX7Z1-0hoyj1LL4MKcFfUKpR7qcVt8,36493
3
+ classy_szfast/classy_szfast.py,sha256=QGXqyzMFdF8mT0RxpJxwBswvM-WRmYFgRkfzUawTmH0,37459
4
4
  classy_szfast/config.py,sha256=cd7Z62-qnX_4FJWfUNqcyJVh-AdBiXrF8DcQGpyAUZM,274
5
5
  classy_szfast/cosmopower.py,sha256=ooYK2BDOZSo3XtGHfPtjXHxr5UW-yVngLPkb5gpvTx8,2351
6
- classy_szfast/cosmopower_jax.py,sha256=C7NzfMFs9sL8rKuDdXdmwxk0UzHqNJnVjZENak-EPQA,2151
6
+ classy_szfast/cosmopower_jax.py,sha256=Bmyzl15OH4KwmpcXfWm1s4XlvlKKkDfgMy70zhnTV5I,5362
7
7
  classy_szfast/cosmosis_classy_szfast_interface.py,sha256=zAnxvFtn73a5yS7jgs59zpWFEYKCIQyraYPs5hQ4Le8,11483
8
8
  classy_szfast/emulators_meta_data.py,sha256=-lHneGhSJ2481S48viz_bNeCyAGu1Ogee0jFEB8B618,9724
9
9
  classy_szfast/pks_and_sigmas.py,sha256=drtuujE1HhlrYY1hY92DyY5lXlYS1uE15MSuVI4uo6k,6625
@@ -14,7 +14,7 @@ classy_szfast/custom_bias/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJW
14
14
  classy_szfast/custom_bias/custom_bias.py,sha256=aR2t5RTIwv7P0m2bsEU0Eq6BTkj4pG10AebH6QpG4qM,486
15
15
  classy_szfast/custom_profiles/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
16
16
  classy_szfast/custom_profiles/custom_profiles.py,sha256=4LZwb2XoqwCyWNmW2s24Z7AJdmgVdaRG7yYaBYe-d9Q,1188
17
- classy_szfast-0.0.25.post1.dist-info/METADATA,sha256=aW7Rr2NRSpqm9RQEug4v7-p_qozJx7us_gqaTXoOIfk,548
18
- classy_szfast-0.0.25.post1.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
19
- classy_szfast-0.0.25.post1.dist-info/top_level.txt,sha256=hRgqpilUck4lx2KkaWI2y9aCDKqF6pFfGHfNaoPFxv0,14
20
- classy_szfast-0.0.25.post1.dist-info/RECORD,,
17
+ classy_szfast-0.0.25.post3.dist-info/METADATA,sha256=JEdn1Go9uH7wuk0OWfoLzX9fRSKNsVI_fH9SF6AM9vQ,548
18
+ classy_szfast-0.0.25.post3.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
19
+ classy_szfast-0.0.25.post3.dist-info/top_level.txt,sha256=hRgqpilUck4lx2KkaWI2y9aCDKqF6pFfGHfNaoPFxv0,14
20
+ classy_szfast-0.0.25.post3.dist-info/RECORD,,