classy-szfast 0.0.24__py3-none-any.whl → 0.0.25.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,7 +1,9 @@
1
1
  from .utils import *
2
2
  from .config import *
3
3
  import numpy as np
4
- from .cosmopower import *
4
+ from .emulators_meta_data import emulator_dict, dofftlog_alphas, cp_l_max_scalars
5
+ from .cosmopower import cp_tt_nn, cp_te_nn, cp_ee_nn, cp_ee_nn, cp_pp_nn, cp_pknl_nn, cp_pkl_nn, cp_der_nn, cp_da_nn, cp_h_nn, cp_s8_nn, cp_pkl_fftlog_alphas_real_nn, cp_pkl_fftlog_alphas_imag_nn, cp_pkl_fftlog_alphas_nus
6
+ from .cosmopower_jax import cp_tt_nn_jax, cp_te_nn_jax, cp_ee_nn_jax, cp_ee_nn_jax, cp_pp_nn_jax, cp_pknl_nn_jax, cp_pkl_nn_jax, cp_der_nn_jax, cp_da_nn_jax, cp_h_nn_jax, cp_s8_nn_jax
5
7
  from .pks_and_sigmas import *
6
8
  import scipy
7
9
  import time
@@ -9,7 +11,8 @@ from multiprocessing import Process
9
11
  from mcfit import TophatVar
10
12
  from scipy.interpolate import CubicSpline
11
13
  import pickle
12
-
14
+ import jax.numpy as jnp
15
+ import jax.scipy as jscipy
13
16
 
14
17
  H_units_conv_factor = {"1/Mpc": 1, "km/s/Mpc": Const.c_km_s}
15
18
 
@@ -69,6 +72,11 @@ class Class_szfast(object):
69
72
  except:
70
73
  pass
71
74
  self.logger = logging.getLogger(__name__)
75
+
76
+
77
+ self.jax_mode = params_settings["jax"]
78
+
79
+ # print(f"JAX mode: {self.jax_mode}")
72
80
 
73
81
 
74
82
 
@@ -85,7 +93,12 @@ class Class_szfast(object):
85
93
  self.cp_pkl_nn = cp_pkl_nn
86
94
  self.cp_der_nn = cp_der_nn
87
95
  self.cp_da_nn = cp_da_nn
88
- self.cp_h_nn = cp_h_nn
96
+
97
+ if self.jax_mode:
98
+ self.cp_h_nn = cp_h_nn_jax
99
+ else:
100
+ self.cp_h_nn = cp_h_nn
101
+
89
102
  self.cp_s8_nn = cp_s8_nn
90
103
 
91
104
  self.emulator_dict = emulator_dict
@@ -203,6 +216,7 @@ class Class_szfast(object):
203
216
 
204
217
 
205
218
  self.cp_z_interp = np.linspace(0.,20.,5000)
219
+ self.cp_z_interp_jax = jnp.linspace(0.,20.,5000)
206
220
 
207
221
  self.csz_base = None
208
222
 
@@ -409,6 +423,11 @@ class Class_szfast(object):
409
423
 
410
424
  k_arr = self.cszfast_pk_grid_k
411
425
 
426
+ # print(">>> z_arr:",z_arr)
427
+ # print(">>> k_arr:",k_arr)
428
+ # import sys
429
+
430
+
412
431
 
413
432
  params_values = params_values_dict.copy()
414
433
  update_params_with_defaults(params_values, self.emulator_dict[self.cosmo_model]['default'])
@@ -445,6 +464,11 @@ class Class_szfast(object):
445
464
  params_dict_pp['z_pk_save_nonclass'] = [zp]
446
465
  predicted_pk_spectrum_z.append(self.cp_pkl_nn[self.cosmo_model].predictions_np(params_dict_pp)[0])
447
466
 
467
+ # if abs(zp-0.5) < 0.01:
468
+ # print(">>> predicted_pk_spectrum_z:",predicted_pk_spectrum_z[-1])
469
+ # import pprint
470
+ # pprint.pprint(params_dict_pp)
471
+
448
472
  predicted_pk_spectrum = np.asarray(predicted_pk_spectrum_z)
449
473
 
450
474
 
@@ -453,6 +477,10 @@ class Class_szfast(object):
453
477
  pk_re = pk*self.pk_power_fac
454
478
  pk_re = np.transpose(pk_re)
455
479
 
480
+ # print(">>> pk_re:",pk_re)
481
+ # import sys
482
+ # sys.exit(0)
483
+
456
484
  self.pkl_interp = PowerSpectrumInterpolator(z_arr,k_arr,np.log(pk_re).T,logP=True)
457
485
 
458
486
  self.cszfast_pk_grid_pk = pk_re
@@ -708,17 +736,44 @@ class Class_szfast(object):
708
736
  if isinstance(params_dict['m_ncdm'][0],str):
709
737
  params_dict['m_ncdm'] = [float(params_dict['m_ncdm'][0].split(',')[0])]
710
738
 
711
- self.cp_predicted_hubble = self.cp_h_nn[self.cosmo_model].ten_to_predictions_np(params_dict)[0]
712
-
713
- self.hz_interp = scipy.interpolate.interp1d(
714
- self.cp_z_interp,
715
- self.cp_predicted_hubble,
716
- kind='linear',
717
- axis=-1,
718
- copy=True,
719
- bounds_error=None,
720
- fill_value=np.nan,
721
- assume_sorted=False)
739
+
740
+ if self.jax_mode:
741
+ # print("JAX MODE in hubble")
742
+ # self.cp_predicted_hubble = self.cp_h_nn[self.cosmo_model].ten_to_predictions_np(params_dict)[0]
743
+ # print(params_dict)
744
+ self.cp_predicted_hubble = self.cp_h_nn[self.cosmo_model].predict(params_dict)
745
+ # print("self.cp_predicted_hubble",self.cp_predicted_hubble)
746
+
747
+ # self.hz_interp = jscipy.interpolate.interp1d(
748
+ # self.cp_z_interp_jax,
749
+ # self.cp_predicted_hubble,
750
+ # kind='linear',
751
+ # axis=-1,
752
+ # copy=True,
753
+ # bounds_error=None,
754
+ # fill_value=np.nan,
755
+ # assume_sorted=False)
756
+
757
+ # Assuming `cp_z_interp` and `cp_predicted_hubble` are JAX arrays
758
+ def hz_interp(x):
759
+ return jnp.interp(x, self.cp_z_interp_jax, self.cp_predicted_hubble, left=jnp.nan, right=jnp.nan)
760
+
761
+ self.hz_interp = hz_interp
762
+ # exit()
763
+ else:
764
+ self.cp_predicted_hubble = self.cp_h_nn[self.cosmo_model].ten_to_predictions_np(params_dict)[0]
765
+ # print("self.cp_predicted_hubble",self.cp_predicted_hubble)
766
+
767
+
768
+ self.hz_interp = scipy.interpolate.interp1d(
769
+ self.cp_z_interp,
770
+ self.cp_predicted_hubble,
771
+ kind='linear',
772
+ axis=-1,
773
+ copy=True,
774
+ bounds_error=None,
775
+ fill_value=np.nan,
776
+ assume_sorted=False)
722
777
 
723
778
  def calculate_chi(self,
724
779
  **params_values_dict):
@@ -819,7 +874,10 @@ class Class_szfast(object):
819
874
 
820
875
 
821
876
  def get_hubble(self, z,units="1/Mpc"):
822
- return np.array(self.hz_interp(z)*H_units_conv_factor[units])
877
+ if self.jax_mode:
878
+ return jnp.array(self.hz_interp(z)*H_units_conv_factor[units])
879
+ else:
880
+ return np.array(self.hz_interp(z)*H_units_conv_factor[units])
823
881
 
824
882
  def get_chi(self, z):
825
883
  return np.array(self.chi_interp(z))
@@ -1,218 +1,9 @@
1
- from .utils import *
2
- from .config import *
3
-
1
+ from .config import path_to_class_sz_data
2
+ import numpy as np
4
3
  from .restore_nn import Restore_NN
5
4
  from .restore_nn import Restore_PCAplusNN
6
5
  from .suppress_warnings import suppress_warnings
7
-
8
- dofftlog_alphas = False
9
-
10
- cosmopower_derived_params_names = ['100*theta_s',
11
- 'sigma8',
12
- 'YHe',
13
- 'z_reio',
14
- 'Neff',
15
- 'tau_rec',
16
- 'z_rec',
17
- 'rs_rec',
18
- 'ra_rec',
19
- 'tau_star',
20
- 'z_star',
21
- 'rs_star',
22
- 'ra_star',
23
- 'rs_drag']
24
-
25
- cp_l_max_scalars = 11000 # max multipole of train ing data
26
-
27
- cosmo_model_list = [
28
- 'lcdm',
29
- 'mnu',
30
- 'neff',
31
- 'wcdm',
32
- 'ede',
33
- 'mnu-3states',
34
- 'ede-v2'
35
- ]
36
-
37
- emulator_dict = {}
38
- emulator_dict['lcdm'] = {}
39
- emulator_dict['mnu'] = {}
40
- emulator_dict['neff'] = {}
41
- emulator_dict['wcdm'] = {}
42
- emulator_dict['ede'] = {}
43
- emulator_dict['mnu-3states'] = {}
44
- emulator_dict['ede-v2'] = {}
45
-
46
- ### note on ncdm:
47
- # N_ncdm : 3
48
- # m_ncdm : 0.02, 0.02, 0.02
49
- # deg_ncdm: 1
50
- # and
51
- # N_ncdm: 1
52
- # deg_ncdm: 3
53
- # m_ncdm : 0.02
54
- # are equivalent but deg_ncdm: 3 is much faster.
55
-
56
-
57
- emulator_dict['lcdm']['TT'] = 'TT_v1'
58
- emulator_dict['lcdm']['TE'] = 'TE_v1'
59
- emulator_dict['lcdm']['EE'] = 'EE_v1'
60
- emulator_dict['lcdm']['PP'] = 'PP_v1'
61
- emulator_dict['lcdm']['PKNL'] = 'PKNL_v1'
62
- emulator_dict['lcdm']['PKL'] = 'PKL_v1'
63
- emulator_dict['lcdm']['PKLFFTLOG_ALPHAS_REAL'] = 'PKLFFTLOGALPHAS_creal_v1'
64
- emulator_dict['lcdm']['PKLFFTLOG_ALPHAS_IMAG'] = 'PKLFFTLOGALPHAS_cimag_v1'
65
- emulator_dict['lcdm']['DER'] = 'DER_v1'
66
- emulator_dict['lcdm']['DAZ'] = 'DAZ_v1'
67
- emulator_dict['lcdm']['HZ'] = 'HZ_v1'
68
- emulator_dict['lcdm']['S8Z'] = 'S8Z_v1'
69
- emulator_dict['lcdm']['default'] = {}
70
- emulator_dict['lcdm']['default']['tau_reio'] = 0.054
71
- emulator_dict['lcdm']['default']['H0'] = 67.66
72
- emulator_dict['lcdm']['default']['ln10^{10}A_s'] = 3.047
73
- emulator_dict['lcdm']['default']['omega_b'] = 0.02242
74
- emulator_dict['lcdm']['default']['omega_cdm'] = 0.11933
75
- emulator_dict['lcdm']['default']['n_s'] = 0.9665
76
- emulator_dict['lcdm']['default']['N_ur'] = 2.0328
77
- emulator_dict['lcdm']['default']['N_ncdm'] = 1
78
- emulator_dict['lcdm']['default']['m_ncdm'] = 0.06
79
-
80
- emulator_dict['mnu']['TT'] = 'TT_mnu_v1'
81
- emulator_dict['mnu']['TE'] = 'TE_mnu_v1'
82
- emulator_dict['mnu']['EE'] = 'EE_mnu_v1'
83
- emulator_dict['mnu']['PP'] = 'PP_mnu_v1'
84
- emulator_dict['mnu']['PKNL'] = 'PKNL_mnu_v1'
85
- emulator_dict['mnu']['PKL'] = 'PKL_mnu_v1'
86
- emulator_dict['mnu']['DER'] = 'DER_mnu_v1'
87
- emulator_dict['mnu']['DAZ'] = 'DAZ_mnu_v1'
88
- emulator_dict['mnu']['HZ'] = 'HZ_mnu_v1'
89
- emulator_dict['mnu']['S8Z'] = 'S8Z_mnu_v1'
90
- emulator_dict['mnu']['default'] = {}
91
- emulator_dict['mnu']['default']['tau_reio'] = 0.054
92
- emulator_dict['mnu']['default']['H0'] = 67.66
93
- emulator_dict['mnu']['default']['ln10^{10}A_s'] = 3.047
94
- emulator_dict['mnu']['default']['omega_b'] = 0.02242
95
- emulator_dict['mnu']['default']['omega_cdm'] = 0.11933
96
- emulator_dict['mnu']['default']['n_s'] = 0.9665
97
- emulator_dict['mnu']['default']['N_ur'] = 2.0328
98
- emulator_dict['mnu']['default']['N_ncdm'] = 1
99
- emulator_dict['mnu']['default']['m_ncdm'] = 0.06
100
-
101
- emulator_dict['neff']['TT'] = 'TT_neff_v1'
102
- emulator_dict['neff']['TE'] = 'TE_neff_v1'
103
- emulator_dict['neff']['EE'] = 'EE_neff_v1'
104
- emulator_dict['neff']['PP'] = 'PP_neff_v1'
105
- emulator_dict['neff']['PKNL'] = 'PKNL_neff_v1'
106
- emulator_dict['neff']['PKL'] = 'PKL_neff_v1'
107
- emulator_dict['neff']['DER'] = 'DER_neff_v1'
108
- emulator_dict['neff']['DAZ'] = 'DAZ_neff_v1'
109
- emulator_dict['neff']['HZ'] = 'HZ_neff_v1'
110
- emulator_dict['neff']['S8Z'] = 'S8Z_neff_v1'
111
- emulator_dict['neff']['default'] = {}
112
- emulator_dict['neff']['default']['tau_reio'] = 0.054
113
- emulator_dict['neff']['default']['H0'] = 67.66
114
- emulator_dict['neff']['default']['ln10^{10}A_s'] = 3.047
115
- emulator_dict['neff']['default']['omega_b'] = 0.02242
116
- emulator_dict['neff']['default']['omega_cdm'] = 0.11933
117
- emulator_dict['neff']['default']['n_s'] = 0.9665
118
- emulator_dict['neff']['default']['N_ur'] = 2.0328 # this is the default value in class v2 to get Neff = 3.046
119
- emulator_dict['neff']['default']['N_ncdm'] = 1
120
- emulator_dict['neff']['default']['m_ncdm'] = 0.06
121
-
122
-
123
- emulator_dict['wcdm']['TT'] = 'TT_w_v1'
124
- emulator_dict['wcdm']['TE'] = 'TE_w_v1'
125
- emulator_dict['wcdm']['EE'] = 'EE_w_v1'
126
- emulator_dict['wcdm']['PP'] = 'PP_w_v1'
127
- emulator_dict['wcdm']['PKNL'] = 'PKNL_w_v1'
128
- emulator_dict['wcdm']['PKL'] = 'PKL_w_v1'
129
- emulator_dict['wcdm']['DER'] = 'DER_w_v1'
130
- emulator_dict['wcdm']['DAZ'] = 'DAZ_w_v1'
131
- emulator_dict['wcdm']['HZ'] = 'HZ_w_v1'
132
- emulator_dict['wcdm']['S8Z'] = 'S8Z_w_v1'
133
- emulator_dict['wcdm']['default'] = {}
134
- emulator_dict['wcdm']['default']['tau_reio'] = 0.054
135
- emulator_dict['wcdm']['default']['H0'] = 67.66
136
- emulator_dict['wcdm']['default']['ln10^{10}A_s'] = 3.047
137
- emulator_dict['wcdm']['default']['omega_b'] = 0.02242
138
- emulator_dict['wcdm']['default']['omega_cdm'] = 0.11933
139
- emulator_dict['wcdm']['default']['n_s'] = 0.9665
140
- emulator_dict['wcdm']['default']['N_ur'] = 2.0328 # this is the default value in class v2 to get Neff = 3.046
141
- emulator_dict['wcdm']['default']['N_ncdm'] = 1
142
- emulator_dict['wcdm']['default']['m_ncdm'] = 0.06
143
-
144
- emulator_dict['ede']['TT'] = 'TT_v1'
145
- emulator_dict['ede']['TE'] = 'TE_v1'
146
- emulator_dict['ede']['EE'] = 'EE_v1'
147
- emulator_dict['ede']['PP'] = 'PP_v1'
148
- emulator_dict['ede']['PKNL'] = 'PKNL_v1'
149
- emulator_dict['ede']['PKL'] = 'PKL_v1'
150
- emulator_dict['ede']['DER'] = 'DER_v1'
151
- emulator_dict['ede']['DAZ'] = 'DAZ_v1'
152
- emulator_dict['ede']['HZ'] = 'HZ_v1'
153
- emulator_dict['ede']['S8Z'] = 'S8Z_v1'
154
- emulator_dict['ede']['default'] = {}
155
- emulator_dict['ede']['default']['fEDE'] = 0.001
156
- emulator_dict['ede']['default']['tau_reio'] = 0.054
157
- emulator_dict['ede']['default']['H0'] = 67.66
158
- emulator_dict['ede']['default']['ln10^{10}A_s'] = 3.047
159
- emulator_dict['ede']['default']['omega_b'] = 0.02242
160
- emulator_dict['ede']['default']['omega_cdm'] = 0.11933
161
- emulator_dict['ede']['default']['n_s'] = 0.9665
162
- emulator_dict['ede']['default']['log10z_c'] = 3.562 # e.g. from https://github.com/mwt5345/class_ede/blob/master/class/notebooks-ede/2-CMB-Comparison.ipynb
163
- emulator_dict['ede']['default']['thetai_scf'] = 2.83 # e.g. from https://github.com/mwt5345/class_ede/blob/master/class/notebooks-ede/2-CMB-Comparison.ipynb
164
- emulator_dict['ede']['default']['r'] = 0.
165
- emulator_dict['ede']['default']['N_ur'] = 0.00641 # this is the default value in class v2 to get Neff = 3.046
166
- emulator_dict['ede']['default']['N_ncdm'] = 3
167
- emulator_dict['ede']['default']['m_ncdm'] = 0.02
168
-
169
-
170
- emulator_dict['mnu-3states']['TT'] = 'TT_v1'
171
- emulator_dict['mnu-3states']['TE'] = 'TE_v1'
172
- emulator_dict['mnu-3states']['EE'] = 'EE_v1'
173
- emulator_dict['mnu-3states']['PP'] = 'PP_v1'
174
- emulator_dict['mnu-3states']['PKNL'] = 'PKNL_v1'
175
- emulator_dict['mnu-3states']['PKL'] = 'PKL_v1'
176
- emulator_dict['mnu-3states']['DER'] = 'DER_v1'
177
- emulator_dict['mnu-3states']['DAZ'] = 'DAZ_v1'
178
- emulator_dict['mnu-3states']['HZ'] = 'HZ_v1'
179
- emulator_dict['mnu-3states']['S8Z'] = 'S8Z_v1'
180
- emulator_dict['mnu-3states']['default'] = {}
181
- emulator_dict['mnu-3states']['default']['tau_reio'] = 0.054
182
- emulator_dict['mnu-3states']['default']['H0'] = 67.66
183
- emulator_dict['mnu-3states']['default']['ln10^{10}A_s'] = 3.047
184
- emulator_dict['mnu-3states']['default']['omega_b'] = 0.02242
185
- emulator_dict['mnu-3states']['default']['omega_cdm'] = 0.11933
186
- emulator_dict['mnu-3states']['default']['n_s'] = 0.9665
187
- emulator_dict['mnu-3states']['default']['N_ur'] = 0.00641 # this is the default value in class v2 to get Neff = 3.046
188
- emulator_dict['mnu-3states']['default']['N_ncdm'] = 3
189
- emulator_dict['mnu-3states']['default']['m_ncdm'] = 0.02
190
-
191
- emulator_dict['ede-v2']['TT'] = 'TT_v2'
192
- emulator_dict['ede-v2']['TE'] = 'TE_v2'
193
- emulator_dict['ede-v2']['EE'] = 'EE_v2'
194
- emulator_dict['ede-v2']['PP'] = 'PP_v2'
195
- emulator_dict['ede-v2']['PKNL'] = 'PKNL_v2'
196
- emulator_dict['ede-v2']['PKL'] = 'PKL_v2'
197
- emulator_dict['ede-v2']['DER'] = 'DER_v2'
198
- emulator_dict['ede-v2']['DAZ'] = 'DAZ_v2'
199
- emulator_dict['ede-v2']['HZ'] = 'HZ_v2'
200
- emulator_dict['ede-v2']['S8Z'] = 'S8Z_v2'
201
-
202
- emulator_dict['ede-v2']['default'] = {}
203
- emulator_dict['ede-v2']['default']['fEDE'] = 0.001
204
- emulator_dict['ede-v2']['default']['tau_reio'] = 0.054
205
- emulator_dict['ede-v2']['default']['H0'] = 67.66
206
- emulator_dict['ede-v2']['default']['ln10^{10}A_s'] = 3.047
207
- emulator_dict['ede-v2']['default']['omega_b'] = 0.02242
208
- emulator_dict['ede-v2']['default']['omega_cdm'] = 0.11933
209
- emulator_dict['ede-v2']['default']['n_s'] = 0.9665
210
- emulator_dict['ede-v2']['default']['log10z_c'] = 3.562 # e.g. from https://github.com/mwt5345/class_ede/blob/master/class/notebooks-ede/2-CMB-Comparison.ipynb
211
- emulator_dict['ede-v2']['default']['thetai_scf'] = 2.83 # e.g. from https://github.com/mwt5345/class_ede/blob/master/class/notebooks-ede/2-CMB-Comparison.ipynb
212
- emulator_dict['ede-v2']['default']['r'] = 0.
213
- emulator_dict['ede-v2']['default']['N_ur'] = 0.00441 # this is the default value in class v3 to get Neff = 3.044
214
- emulator_dict['ede-v2']['default']['N_ncdm'] = 3
215
- emulator_dict['ede-v2']['default']['m_ncdm'] = 0.02
6
+ from .emulators_meta_data import *
216
7
 
217
8
 
218
9
  cp_tt_nn = {}
@@ -229,36 +20,6 @@ cp_da_nn = {}
229
20
  cp_h_nn = {}
230
21
  cp_s8_nn = {}
231
22
 
232
- import warnings
233
- from contextlib import contextmanager
234
- import logging
235
-
236
- # Suppress absl warnings
237
- import absl.logging
238
- absl.logging.set_verbosity('error')
239
- # Suppress TensorFlow warnings
240
- import os
241
- os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
242
- with suppress_warnings():
243
- import tensorflow as tf
244
- tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
245
-
246
-
247
- import re
248
-
249
- def split_emulator_string(input_string):
250
- match = re.match(r"(.+)-v(\d+)", input_string)
251
- if match:
252
- folder = match.group(1)
253
- version = match.group(2)
254
- return folder, version
255
- else:
256
- folder = input_string
257
- version = '1'
258
- return folder, version
259
-
260
-
261
-
262
23
 
263
24
  for mp in cosmo_model_list:
264
25
  folder, version = split_emulator_string(mp)
@@ -0,0 +1,53 @@
1
+ from .config import path_to_class_sz_data
2
+ import numpy as np
3
+ from .restore_nn import Restore_NN
4
+ from .restore_nn import Restore_PCAplusNN
5
+ from .suppress_warnings import suppress_warnings
6
+ from .emulators_meta_data import *
7
+
8
+ from cosmopower_jax.cosmopower_jax import CosmoPowerJAX as CPJ
9
+
10
+
11
+ cp_tt_nn_jax = {}
12
+ cp_te_nn_jax = {}
13
+ cp_ee_nn_jax = {}
14
+ cp_pp_nn_jax = {}
15
+ cp_pknl_nn_jax = {}
16
+ cp_pkl_nn_jax = {}
17
+ cp_der_nn_jax = {}
18
+ cp_da_nn_jax = {}
19
+ cp_h_nn_jax = {}
20
+ cp_s8_nn_jax = {}
21
+
22
+
23
+ for mp in cosmo_model_list:
24
+ folder, version = split_emulator_string(mp)
25
+ # print(folder, version)
26
+ path_to_emulators = path_to_class_sz_data + '/' + folder +'/'
27
+
28
+ cp_tt_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'TTTEEE/' + emulator_dict[mp]['TT'])
29
+
30
+ cp_te_nn_jax[mp] = Restore_PCAplusNN(restore_filename=path_to_emulators + 'TTTEEE/' + emulator_dict[mp]['TE'])
31
+
32
+ with suppress_warnings():
33
+ cp_ee_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'TTTEEE/' + emulator_dict[mp]['EE'])
34
+
35
+ cp_pp_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'PP/' + emulator_dict[mp]['PP'])
36
+
37
+ cp_pknl_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'PK/' + emulator_dict[mp]['PKNL'])
38
+
39
+ cp_pkl_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'PK/' + emulator_dict[mp]['PKL'])
40
+
41
+ cp_der_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'derived-parameters/' + emulator_dict[mp]['DER'])
42
+
43
+ cp_da_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['DAZ'])
44
+
45
+ # print(path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['HZ'])
46
+ emulator_custom = CPJ(probe='custom_log',filepath=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['HZ'] + '.npz')
47
+ # print(emulator_custom.parameters)
48
+ # exit()
49
+
50
+ cp_h_nn_jax[mp] = CPJ(probe='custom_log',filepath=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['HZ'] + '.npz')
51
+
52
+ cp_s8_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['S8Z'])
53
+
@@ -0,0 +1,242 @@
1
+ from .suppress_warnings import suppress_warnings
2
+ import warnings
3
+ from contextlib import contextmanager
4
+ import logging
5
+
6
+ # Suppress absl warnings
7
+ import absl.logging
8
+ absl.logging.set_verbosity('error')
9
+ # Suppress TensorFlow warnings
10
+ import os
11
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
12
+ with suppress_warnings():
13
+ import tensorflow as tf
14
+ tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
15
+
16
+
17
+ import re
18
+
19
+ dofftlog_alphas = False
20
+
21
+ cosmopower_derived_params_names = ['100*theta_s',
22
+ 'sigma8',
23
+ 'YHe',
24
+ 'z_reio',
25
+ 'Neff',
26
+ 'tau_rec',
27
+ 'z_rec',
28
+ 'rs_rec',
29
+ 'ra_rec',
30
+ 'tau_star',
31
+ 'z_star',
32
+ 'rs_star',
33
+ 'ra_star',
34
+ 'rs_drag']
35
+
36
+ cp_l_max_scalars = 11000 # max multipole of train ing data
37
+
38
+ cosmo_model_list = [
39
+ 'lcdm',
40
+ 'mnu',
41
+ 'neff',
42
+ 'wcdm',
43
+ 'ede',
44
+ 'mnu-3states',
45
+ 'ede-v2'
46
+ ]
47
+
48
+ emulator_dict = {}
49
+ emulator_dict['lcdm'] = {}
50
+ emulator_dict['mnu'] = {}
51
+ emulator_dict['neff'] = {}
52
+ emulator_dict['wcdm'] = {}
53
+ emulator_dict['ede'] = {}
54
+ emulator_dict['mnu-3states'] = {}
55
+ emulator_dict['ede-v2'] = {}
56
+
57
+ ### note on ncdm:
58
+ # N_ncdm : 3
59
+ # m_ncdm : 0.02, 0.02, 0.02
60
+ # deg_ncdm: 1
61
+ # and
62
+ # N_ncdm: 1
63
+ # deg_ncdm: 3
64
+ # m_ncdm : 0.02
65
+ # are equivalent but deg_ncdm: 3 is much faster.
66
+
67
+
68
+ emulator_dict['lcdm']['TT'] = 'TT_v1'
69
+ emulator_dict['lcdm']['TE'] = 'TE_v1'
70
+ emulator_dict['lcdm']['EE'] = 'EE_v1'
71
+ emulator_dict['lcdm']['PP'] = 'PP_v1'
72
+ emulator_dict['lcdm']['PKNL'] = 'PKNL_v1'
73
+ emulator_dict['lcdm']['PKL'] = 'PKL_v1'
74
+ emulator_dict['lcdm']['PKLFFTLOG_ALPHAS_REAL'] = 'PKLFFTLOGALPHAS_creal_v1'
75
+ emulator_dict['lcdm']['PKLFFTLOG_ALPHAS_IMAG'] = 'PKLFFTLOGALPHAS_cimag_v1'
76
+ emulator_dict['lcdm']['DER'] = 'DER_v1'
77
+ emulator_dict['lcdm']['DAZ'] = 'DAZ_v1'
78
+ emulator_dict['lcdm']['HZ'] = 'HZ_v1'
79
+ emulator_dict['lcdm']['S8Z'] = 'S8Z_v1'
80
+ emulator_dict['lcdm']['default'] = {}
81
+ emulator_dict['lcdm']['default']['tau_reio'] = 0.054
82
+ emulator_dict['lcdm']['default']['H0'] = 67.66
83
+ emulator_dict['lcdm']['default']['ln10^{10}A_s'] = 3.047
84
+ emulator_dict['lcdm']['default']['omega_b'] = 0.02242
85
+ emulator_dict['lcdm']['default']['omega_cdm'] = 0.11933
86
+ emulator_dict['lcdm']['default']['n_s'] = 0.9665
87
+ emulator_dict['lcdm']['default']['N_ur'] = 2.0328
88
+ emulator_dict['lcdm']['default']['N_ncdm'] = 1
89
+ emulator_dict['lcdm']['default']['m_ncdm'] = 0.06
90
+
91
+ emulator_dict['mnu']['TT'] = 'TT_mnu_v1'
92
+ emulator_dict['mnu']['TE'] = 'TE_mnu_v1'
93
+ emulator_dict['mnu']['EE'] = 'EE_mnu_v1'
94
+ emulator_dict['mnu']['PP'] = 'PP_mnu_v1'
95
+ emulator_dict['mnu']['PKNL'] = 'PKNL_mnu_v1'
96
+ emulator_dict['mnu']['PKL'] = 'PKL_mnu_v1'
97
+ emulator_dict['mnu']['DER'] = 'DER_mnu_v1'
98
+ emulator_dict['mnu']['DAZ'] = 'DAZ_mnu_v1'
99
+ emulator_dict['mnu']['HZ'] = 'HZ_mnu_v1'
100
+ emulator_dict['mnu']['S8Z'] = 'S8Z_mnu_v1'
101
+ emulator_dict['mnu']['default'] = {}
102
+ emulator_dict['mnu']['default']['tau_reio'] = 0.054
103
+ emulator_dict['mnu']['default']['H0'] = 67.66
104
+ emulator_dict['mnu']['default']['ln10^{10}A_s'] = 3.047
105
+ emulator_dict['mnu']['default']['omega_b'] = 0.02242
106
+ emulator_dict['mnu']['default']['omega_cdm'] = 0.11933
107
+ emulator_dict['mnu']['default']['n_s'] = 0.9665
108
+ emulator_dict['mnu']['default']['N_ur'] = 2.0328
109
+ emulator_dict['mnu']['default']['N_ncdm'] = 1
110
+ emulator_dict['mnu']['default']['m_ncdm'] = 0.06
111
+
112
+ emulator_dict['neff']['TT'] = 'TT_neff_v1'
113
+ emulator_dict['neff']['TE'] = 'TE_neff_v1'
114
+ emulator_dict['neff']['EE'] = 'EE_neff_v1'
115
+ emulator_dict['neff']['PP'] = 'PP_neff_v1'
116
+ emulator_dict['neff']['PKNL'] = 'PKNL_neff_v1'
117
+ emulator_dict['neff']['PKL'] = 'PKL_neff_v1'
118
+ emulator_dict['neff']['DER'] = 'DER_neff_v1'
119
+ emulator_dict['neff']['DAZ'] = 'DAZ_neff_v1'
120
+ emulator_dict['neff']['HZ'] = 'HZ_neff_v1'
121
+ emulator_dict['neff']['S8Z'] = 'S8Z_neff_v1'
122
+ emulator_dict['neff']['default'] = {}
123
+ emulator_dict['neff']['default']['tau_reio'] = 0.054
124
+ emulator_dict['neff']['default']['H0'] = 67.66
125
+ emulator_dict['neff']['default']['ln10^{10}A_s'] = 3.047
126
+ emulator_dict['neff']['default']['omega_b'] = 0.02242
127
+ emulator_dict['neff']['default']['omega_cdm'] = 0.11933
128
+ emulator_dict['neff']['default']['n_s'] = 0.9665
129
+ emulator_dict['neff']['default']['N_ur'] = 2.0328 # this is the default value in class v2 to get Neff = 3.046
130
+ emulator_dict['neff']['default']['N_ncdm'] = 1
131
+ emulator_dict['neff']['default']['m_ncdm'] = 0.06
132
+
133
+
134
+ emulator_dict['wcdm']['TT'] = 'TT_w_v1'
135
+ emulator_dict['wcdm']['TE'] = 'TE_w_v1'
136
+ emulator_dict['wcdm']['EE'] = 'EE_w_v1'
137
+ emulator_dict['wcdm']['PP'] = 'PP_w_v1'
138
+ emulator_dict['wcdm']['PKNL'] = 'PKNL_w_v1'
139
+ emulator_dict['wcdm']['PKL'] = 'PKL_w_v1'
140
+ emulator_dict['wcdm']['DER'] = 'DER_w_v1'
141
+ emulator_dict['wcdm']['DAZ'] = 'DAZ_w_v1'
142
+ emulator_dict['wcdm']['HZ'] = 'HZ_w_v1'
143
+ emulator_dict['wcdm']['S8Z'] = 'S8Z_w_v1'
144
+ emulator_dict['wcdm']['default'] = {}
145
+ emulator_dict['wcdm']['default']['tau_reio'] = 0.054
146
+ emulator_dict['wcdm']['default']['H0'] = 67.66
147
+ emulator_dict['wcdm']['default']['ln10^{10}A_s'] = 3.047
148
+ emulator_dict['wcdm']['default']['omega_b'] = 0.02242
149
+ emulator_dict['wcdm']['default']['omega_cdm'] = 0.11933
150
+ emulator_dict['wcdm']['default']['n_s'] = 0.9665
151
+ emulator_dict['wcdm']['default']['N_ur'] = 2.0328 # this is the default value in class v2 to get Neff = 3.046
152
+ emulator_dict['wcdm']['default']['N_ncdm'] = 1
153
+ emulator_dict['wcdm']['default']['m_ncdm'] = 0.06
154
+
155
+ emulator_dict['ede']['TT'] = 'TT_v1'
156
+ emulator_dict['ede']['TE'] = 'TE_v1'
157
+ emulator_dict['ede']['EE'] = 'EE_v1'
158
+ emulator_dict['ede']['PP'] = 'PP_v1'
159
+ emulator_dict['ede']['PKNL'] = 'PKNL_v1'
160
+ emulator_dict['ede']['PKL'] = 'PKL_v1'
161
+ emulator_dict['ede']['DER'] = 'DER_v1'
162
+ emulator_dict['ede']['DAZ'] = 'DAZ_v1'
163
+ emulator_dict['ede']['HZ'] = 'HZ_v1'
164
+ emulator_dict['ede']['S8Z'] = 'S8Z_v1'
165
+ emulator_dict['ede']['default'] = {}
166
+ emulator_dict['ede']['default']['fEDE'] = 0.001
167
+ emulator_dict['ede']['default']['tau_reio'] = 0.054
168
+ emulator_dict['ede']['default']['H0'] = 67.66
169
+ emulator_dict['ede']['default']['ln10^{10}A_s'] = 3.047
170
+ emulator_dict['ede']['default']['omega_b'] = 0.02242
171
+ emulator_dict['ede']['default']['omega_cdm'] = 0.11933
172
+ emulator_dict['ede']['default']['n_s'] = 0.9665
173
+ emulator_dict['ede']['default']['log10z_c'] = 3.562 # e.g. from https://github.com/mwt5345/class_ede/blob/master/class/notebooks-ede/2-CMB-Comparison.ipynb
174
+ emulator_dict['ede']['default']['thetai_scf'] = 2.83 # e.g. from https://github.com/mwt5345/class_ede/blob/master/class/notebooks-ede/2-CMB-Comparison.ipynb
175
+ emulator_dict['ede']['default']['r'] = 0.
176
+ emulator_dict['ede']['default']['N_ur'] = 0.00641 # this is the default value in class v2 to get Neff = 3.046
177
+ emulator_dict['ede']['default']['N_ncdm'] = 3
178
+ emulator_dict['ede']['default']['m_ncdm'] = 0.02
179
+
180
+
181
+ emulator_dict['mnu-3states']['TT'] = 'TT_v1'
182
+ emulator_dict['mnu-3states']['TE'] = 'TE_v1'
183
+ emulator_dict['mnu-3states']['EE'] = 'EE_v1'
184
+ emulator_dict['mnu-3states']['PP'] = 'PP_v1'
185
+ emulator_dict['mnu-3states']['PKNL'] = 'PKNL_v1'
186
+ emulator_dict['mnu-3states']['PKL'] = 'PKL_v1'
187
+ emulator_dict['mnu-3states']['DER'] = 'DER_v1'
188
+ emulator_dict['mnu-3states']['DAZ'] = 'DAZ_v1'
189
+ emulator_dict['mnu-3states']['HZ'] = 'HZ_v1'
190
+ emulator_dict['mnu-3states']['S8Z'] = 'S8Z_v1'
191
+ emulator_dict['mnu-3states']['default'] = {}
192
+ emulator_dict['mnu-3states']['default']['tau_reio'] = 0.054
193
+ emulator_dict['mnu-3states']['default']['H0'] = 67.66
194
+ emulator_dict['mnu-3states']['default']['ln10^{10}A_s'] = 3.047
195
+ emulator_dict['mnu-3states']['default']['omega_b'] = 0.02242
196
+ emulator_dict['mnu-3states']['default']['omega_cdm'] = 0.11933
197
+ emulator_dict['mnu-3states']['default']['n_s'] = 0.9665
198
+ emulator_dict['mnu-3states']['default']['N_ur'] = 0.00641 # this is the default value in class v2 to get Neff = 3.046
199
+ emulator_dict['mnu-3states']['default']['N_ncdm'] = 3
200
+ emulator_dict['mnu-3states']['default']['m_ncdm'] = 0.02
201
+
202
+ emulator_dict['ede-v2']['TT'] = 'TT_v2'
203
+ emulator_dict['ede-v2']['TE'] = 'TE_v2'
204
+ emulator_dict['ede-v2']['EE'] = 'EE_v2'
205
+ emulator_dict['ede-v2']['PP'] = 'PP_v2'
206
+ emulator_dict['ede-v2']['PKNL'] = 'PKNL_v2'
207
+ emulator_dict['ede-v2']['PKL'] = 'PKL_v2'
208
+ emulator_dict['ede-v2']['DER'] = 'DER_v2'
209
+ emulator_dict['ede-v2']['DAZ'] = 'DAZ_v2'
210
+ emulator_dict['ede-v2']['HZ'] = 'HZ_v2'
211
+ emulator_dict['ede-v2']['S8Z'] = 'S8Z_v2'
212
+
213
+ emulator_dict['ede-v2']['default'] = {}
214
+ emulator_dict['ede-v2']['default']['fEDE'] = 0.001
215
+ emulator_dict['ede-v2']['default']['tau_reio'] = 0.054
216
+ emulator_dict['ede-v2']['default']['H0'] = 67.66
217
+ emulator_dict['ede-v2']['default']['ln10^{10}A_s'] = 3.047
218
+ emulator_dict['ede-v2']['default']['omega_b'] = 0.02242
219
+ emulator_dict['ede-v2']['default']['omega_cdm'] = 0.11933
220
+ emulator_dict['ede-v2']['default']['n_s'] = 0.9665
221
+ emulator_dict['ede-v2']['default']['log10z_c'] = 3.562 # e.g. from https://github.com/mwt5345/class_ede/blob/master/class/notebooks-ede/2-CMB-Comparison.ipynb
222
+ emulator_dict['ede-v2']['default']['thetai_scf'] = 2.83 # e.g. from https://github.com/mwt5345/class_ede/blob/master/class/notebooks-ede/2-CMB-Comparison.ipynb
223
+ emulator_dict['ede-v2']['default']['r'] = 0.
224
+ emulator_dict['ede-v2']['default']['N_ur'] = 0.00441 # this is the default value in class v3 to get Neff = 3.044
225
+ emulator_dict['ede-v2']['default']['N_ncdm'] = 3
226
+ emulator_dict['ede-v2']['default']['m_ncdm'] = 0.02
227
+
228
+
229
+
230
+
231
+ def split_emulator_string(input_string):
232
+ match = re.match(r"(.+)-v(\d+)", input_string)
233
+ if match:
234
+ folder = match.group(1)
235
+ version = match.group(2)
236
+ return folder, version
237
+ else:
238
+ folder = input_string
239
+ version = '1'
240
+ return folder, version
241
+
242
+
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: classy_szfast
3
- Version: 0.0.24
3
+ Version: 0.0.25.post1
4
4
  Summary: The accelerator of the class_sz code from https://github.com/CLASS-SZ
5
5
  Maintainer-email: Boris Bolliet <bb667@cam.ac.uk>
6
6
  License: MIT
@@ -13,4 +13,5 @@ Requires-Dist: tensorflow
13
13
  Requires-Dist: mcfit
14
14
  Requires-Dist: get-cosmopower-emus
15
15
  Requires-Dist: class-sz-data
16
+ Requires-Dist: cosmopower-jax
16
17
 
@@ -1,9 +1,11 @@
1
1
  classy_szfast/__init__.py,sha256=E2thrL0Z9oXFfdzwcsu-xbOytudLFTlRlPqVFGlPPPg,279
2
2
  classy_szfast/classy_sz.py,sha256=QmbwrSXInQLMvCDqsr7KPmtaU0KOiOt1Rb-cTKuulZw,22240
3
- classy_szfast/classy_szfast.py,sha256=___sH2Xx-8AdK2bv-HsFwvZodjnlx5IX02If88Tpj5A,33800
3
+ classy_szfast/classy_szfast.py,sha256=p2N3UYK0Gmy79wX7Z1-0hoyj1LL4MKcFfUKpR7qcVt8,36493
4
4
  classy_szfast/config.py,sha256=cd7Z62-qnX_4FJWfUNqcyJVh-AdBiXrF8DcQGpyAUZM,274
5
- classy_szfast/cosmopower.py,sha256=qM6b3myW84_o_sOrdtKLui3DhOwxMnJA9cnmXDBE0Ps,11972
5
+ classy_szfast/cosmopower.py,sha256=ooYK2BDOZSo3XtGHfPtjXHxr5UW-yVngLPkb5gpvTx8,2351
6
+ classy_szfast/cosmopower_jax.py,sha256=C7NzfMFs9sL8rKuDdXdmwxk0UzHqNJnVjZENak-EPQA,2151
6
7
  classy_szfast/cosmosis_classy_szfast_interface.py,sha256=zAnxvFtn73a5yS7jgs59zpWFEYKCIQyraYPs5hQ4Le8,11483
8
+ classy_szfast/emulators_meta_data.py,sha256=-lHneGhSJ2481S48viz_bNeCyAGu1Ogee0jFEB8B618,9724
7
9
  classy_szfast/pks_and_sigmas.py,sha256=drtuujE1HhlrYY1hY92DyY5lXlYS1uE15MSuVI4uo6k,6625
8
10
  classy_szfast/restore_nn.py,sha256=DqA9thhTRiGBDVb9zjhqcbF2W4V0AU0vrjJFhnLboU4,21075
9
11
  classy_szfast/suppress_warnings.py,sha256=6wIBml2Sj9DyRGZlZWhuA9hqvpxqrNyYjuz6BPK_a6E,202
@@ -12,7 +14,7 @@ classy_szfast/custom_bias/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJW
12
14
  classy_szfast/custom_bias/custom_bias.py,sha256=aR2t5RTIwv7P0m2bsEU0Eq6BTkj4pG10AebH6QpG4qM,486
13
15
  classy_szfast/custom_profiles/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
16
  classy_szfast/custom_profiles/custom_profiles.py,sha256=4LZwb2XoqwCyWNmW2s24Z7AJdmgVdaRG7yYaBYe-d9Q,1188
15
- classy_szfast-0.0.24.dist-info/METADATA,sha256=I8_DtLaWDTMH9my7frjNAGWpPsmQpIX0qxtkk60v-zI,512
16
- classy_szfast-0.0.24.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
17
- classy_szfast-0.0.24.dist-info/top_level.txt,sha256=hRgqpilUck4lx2KkaWI2y9aCDKqF6pFfGHfNaoPFxv0,14
18
- classy_szfast-0.0.24.dist-info/RECORD,,
17
+ classy_szfast-0.0.25.post1.dist-info/METADATA,sha256=aW7Rr2NRSpqm9RQEug4v7-p_qozJx7us_gqaTXoOIfk,548
18
+ classy_szfast-0.0.25.post1.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
19
+ classy_szfast-0.0.25.post1.dist-info/top_level.txt,sha256=hRgqpilUck4lx2KkaWI2y9aCDKqF6pFfGHfNaoPFxv0,14
20
+ classy_szfast-0.0.25.post1.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.1.0)
2
+ Generator: setuptools (75.3.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5