classy-szfast 0.0.24__tar.gz → 0.0.25.post1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (25) hide show
  1. {classy_szfast-0.0.24 → classy_szfast-0.0.25.post1}/PKG-INFO +2 -1
  2. {classy_szfast-0.0.24 → classy_szfast-0.0.25.post1}/classy_szfast/classy_szfast.py +73 -15
  3. classy_szfast-0.0.25.post1/classy_szfast/cosmopower.py +59 -0
  4. classy_szfast-0.0.25.post1/classy_szfast/cosmopower_jax.py +53 -0
  5. classy_szfast-0.0.24/classy_szfast/cosmopower.py → classy_szfast-0.0.25.post1/classy_szfast/emulators_meta_data.py +16 -72
  6. {classy_szfast-0.0.24 → classy_szfast-0.0.25.post1}/classy_szfast.egg-info/PKG-INFO +2 -1
  7. {classy_szfast-0.0.24 → classy_szfast-0.0.25.post1}/classy_szfast.egg-info/SOURCES.txt +2 -0
  8. {classy_szfast-0.0.24 → classy_szfast-0.0.25.post1}/classy_szfast.egg-info/requires.txt +1 -0
  9. {classy_szfast-0.0.24 → classy_szfast-0.0.25.post1}/pyproject.toml +3 -2
  10. {classy_szfast-0.0.24 → classy_szfast-0.0.25.post1}/README.md +0 -0
  11. {classy_szfast-0.0.24 → classy_szfast-0.0.25.post1}/classy_szfast/__init__.py +0 -0
  12. {classy_szfast-0.0.24 → classy_szfast-0.0.25.post1}/classy_szfast/classy_sz.py +0 -0
  13. {classy_szfast-0.0.24 → classy_szfast-0.0.25.post1}/classy_szfast/config.py +0 -0
  14. {classy_szfast-0.0.24 → classy_szfast-0.0.25.post1}/classy_szfast/cosmosis_classy_szfast_interface.py +0 -0
  15. {classy_szfast-0.0.24 → classy_szfast-0.0.25.post1}/classy_szfast/custom_bias/__init__.py +0 -0
  16. {classy_szfast-0.0.24 → classy_szfast-0.0.25.post1}/classy_szfast/custom_bias/custom_bias.py +0 -0
  17. {classy_szfast-0.0.24 → classy_szfast-0.0.25.post1}/classy_szfast/custom_profiles/__init__.py +0 -0
  18. {classy_szfast-0.0.24 → classy_szfast-0.0.25.post1}/classy_szfast/custom_profiles/custom_profiles.py +0 -0
  19. {classy_szfast-0.0.24 → classy_szfast-0.0.25.post1}/classy_szfast/pks_and_sigmas.py +0 -0
  20. {classy_szfast-0.0.24 → classy_szfast-0.0.25.post1}/classy_szfast/restore_nn.py +0 -0
  21. {classy_szfast-0.0.24 → classy_szfast-0.0.25.post1}/classy_szfast/suppress_warnings.py +0 -0
  22. {classy_szfast-0.0.24 → classy_szfast-0.0.25.post1}/classy_szfast/utils.py +0 -0
  23. {classy_szfast-0.0.24 → classy_szfast-0.0.25.post1}/classy_szfast.egg-info/dependency_links.txt +0 -0
  24. {classy_szfast-0.0.24 → classy_szfast-0.0.25.post1}/classy_szfast.egg-info/top_level.txt +0 -0
  25. {classy_szfast-0.0.24 → classy_szfast-0.0.25.post1}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: classy_szfast
3
- Version: 0.0.24
3
+ Version: 0.0.25.post1
4
4
  Summary: The accelerator of the class_sz code from https://github.com/CLASS-SZ
5
5
  Maintainer-email: Boris Bolliet <bb667@cam.ac.uk>
6
6
  License: MIT
@@ -13,3 +13,4 @@ Requires-Dist: tensorflow
13
13
  Requires-Dist: mcfit
14
14
  Requires-Dist: get_cosmopower_emus
15
15
  Requires-Dist: class_sz_data
16
+ Requires-Dist: cosmopower-jax
@@ -1,7 +1,9 @@
1
1
  from .utils import *
2
2
  from .config import *
3
3
  import numpy as np
4
- from .cosmopower import *
4
+ from .emulators_meta_data import emulator_dict, dofftlog_alphas, cp_l_max_scalars
5
+ from .cosmopower import cp_tt_nn, cp_te_nn, cp_ee_nn, cp_ee_nn, cp_pp_nn, cp_pknl_nn, cp_pkl_nn, cp_der_nn, cp_da_nn, cp_h_nn, cp_s8_nn, cp_pkl_fftlog_alphas_real_nn, cp_pkl_fftlog_alphas_imag_nn, cp_pkl_fftlog_alphas_nus
6
+ from .cosmopower_jax import cp_tt_nn_jax, cp_te_nn_jax, cp_ee_nn_jax, cp_ee_nn_jax, cp_pp_nn_jax, cp_pknl_nn_jax, cp_pkl_nn_jax, cp_der_nn_jax, cp_da_nn_jax, cp_h_nn_jax, cp_s8_nn_jax
5
7
  from .pks_and_sigmas import *
6
8
  import scipy
7
9
  import time
@@ -9,7 +11,8 @@ from multiprocessing import Process
9
11
  from mcfit import TophatVar
10
12
  from scipy.interpolate import CubicSpline
11
13
  import pickle
12
-
14
+ import jax.numpy as jnp
15
+ import jax.scipy as jscipy
13
16
 
14
17
  H_units_conv_factor = {"1/Mpc": 1, "km/s/Mpc": Const.c_km_s}
15
18
 
@@ -69,6 +72,11 @@ class Class_szfast(object):
69
72
  except:
70
73
  pass
71
74
  self.logger = logging.getLogger(__name__)
75
+
76
+
77
+ self.jax_mode = params_settings["jax"]
78
+
79
+ # print(f"JAX mode: {self.jax_mode}")
72
80
 
73
81
 
74
82
 
@@ -85,7 +93,12 @@ class Class_szfast(object):
85
93
  self.cp_pkl_nn = cp_pkl_nn
86
94
  self.cp_der_nn = cp_der_nn
87
95
  self.cp_da_nn = cp_da_nn
88
- self.cp_h_nn = cp_h_nn
96
+
97
+ if self.jax_mode:
98
+ self.cp_h_nn = cp_h_nn_jax
99
+ else:
100
+ self.cp_h_nn = cp_h_nn
101
+
89
102
  self.cp_s8_nn = cp_s8_nn
90
103
 
91
104
  self.emulator_dict = emulator_dict
@@ -203,6 +216,7 @@ class Class_szfast(object):
203
216
 
204
217
 
205
218
  self.cp_z_interp = np.linspace(0.,20.,5000)
219
+ self.cp_z_interp_jax = jnp.linspace(0.,20.,5000)
206
220
 
207
221
  self.csz_base = None
208
222
 
@@ -409,6 +423,11 @@ class Class_szfast(object):
409
423
 
410
424
  k_arr = self.cszfast_pk_grid_k
411
425
 
426
+ # print(">>> z_arr:",z_arr)
427
+ # print(">>> k_arr:",k_arr)
428
+ # import sys
429
+
430
+
412
431
 
413
432
  params_values = params_values_dict.copy()
414
433
  update_params_with_defaults(params_values, self.emulator_dict[self.cosmo_model]['default'])
@@ -445,6 +464,11 @@ class Class_szfast(object):
445
464
  params_dict_pp['z_pk_save_nonclass'] = [zp]
446
465
  predicted_pk_spectrum_z.append(self.cp_pkl_nn[self.cosmo_model].predictions_np(params_dict_pp)[0])
447
466
 
467
+ # if abs(zp-0.5) < 0.01:
468
+ # print(">>> predicted_pk_spectrum_z:",predicted_pk_spectrum_z[-1])
469
+ # import pprint
470
+ # pprint.pprint(params_dict_pp)
471
+
448
472
  predicted_pk_spectrum = np.asarray(predicted_pk_spectrum_z)
449
473
 
450
474
 
@@ -453,6 +477,10 @@ class Class_szfast(object):
453
477
  pk_re = pk*self.pk_power_fac
454
478
  pk_re = np.transpose(pk_re)
455
479
 
480
+ # print(">>> pk_re:",pk_re)
481
+ # import sys
482
+ # sys.exit(0)
483
+
456
484
  self.pkl_interp = PowerSpectrumInterpolator(z_arr,k_arr,np.log(pk_re).T,logP=True)
457
485
 
458
486
  self.cszfast_pk_grid_pk = pk_re
@@ -708,17 +736,44 @@ class Class_szfast(object):
708
736
  if isinstance(params_dict['m_ncdm'][0],str):
709
737
  params_dict['m_ncdm'] = [float(params_dict['m_ncdm'][0].split(',')[0])]
710
738
 
711
- self.cp_predicted_hubble = self.cp_h_nn[self.cosmo_model].ten_to_predictions_np(params_dict)[0]
712
-
713
- self.hz_interp = scipy.interpolate.interp1d(
714
- self.cp_z_interp,
715
- self.cp_predicted_hubble,
716
- kind='linear',
717
- axis=-1,
718
- copy=True,
719
- bounds_error=None,
720
- fill_value=np.nan,
721
- assume_sorted=False)
739
+
740
+ if self.jax_mode:
741
+ # print("JAX MODE in hubble")
742
+ # self.cp_predicted_hubble = self.cp_h_nn[self.cosmo_model].ten_to_predictions_np(params_dict)[0]
743
+ # print(params_dict)
744
+ self.cp_predicted_hubble = self.cp_h_nn[self.cosmo_model].predict(params_dict)
745
+ # print("self.cp_predicted_hubble",self.cp_predicted_hubble)
746
+
747
+ # self.hz_interp = jscipy.interpolate.interp1d(
748
+ # self.cp_z_interp_jax,
749
+ # self.cp_predicted_hubble,
750
+ # kind='linear',
751
+ # axis=-1,
752
+ # copy=True,
753
+ # bounds_error=None,
754
+ # fill_value=np.nan,
755
+ # assume_sorted=False)
756
+
757
+ # Assuming `cp_z_interp` and `cp_predicted_hubble` are JAX arrays
758
+ def hz_interp(x):
759
+ return jnp.interp(x, self.cp_z_interp_jax, self.cp_predicted_hubble, left=jnp.nan, right=jnp.nan)
760
+
761
+ self.hz_interp = hz_interp
762
+ # exit()
763
+ else:
764
+ self.cp_predicted_hubble = self.cp_h_nn[self.cosmo_model].ten_to_predictions_np(params_dict)[0]
765
+ # print("self.cp_predicted_hubble",self.cp_predicted_hubble)
766
+
767
+
768
+ self.hz_interp = scipy.interpolate.interp1d(
769
+ self.cp_z_interp,
770
+ self.cp_predicted_hubble,
771
+ kind='linear',
772
+ axis=-1,
773
+ copy=True,
774
+ bounds_error=None,
775
+ fill_value=np.nan,
776
+ assume_sorted=False)
722
777
 
723
778
  def calculate_chi(self,
724
779
  **params_values_dict):
@@ -819,7 +874,10 @@ class Class_szfast(object):
819
874
 
820
875
 
821
876
  def get_hubble(self, z,units="1/Mpc"):
822
- return np.array(self.hz_interp(z)*H_units_conv_factor[units])
877
+ if self.jax_mode:
878
+ return jnp.array(self.hz_interp(z)*H_units_conv_factor[units])
879
+ else:
880
+ return np.array(self.hz_interp(z)*H_units_conv_factor[units])
823
881
 
824
882
  def get_chi(self, z):
825
883
  return np.array(self.chi_interp(z))
@@ -0,0 +1,59 @@
1
+ from .config import path_to_class_sz_data
2
+ import numpy as np
3
+ from .restore_nn import Restore_NN
4
+ from .restore_nn import Restore_PCAplusNN
5
+ from .suppress_warnings import suppress_warnings
6
+ from .emulators_meta_data import *
7
+
8
+
9
+ cp_tt_nn = {}
10
+ cp_te_nn = {}
11
+ cp_ee_nn = {}
12
+ cp_pp_nn = {}
13
+ cp_pknl_nn = {}
14
+ cp_pkl_nn = {}
15
+ cp_pkl_fftlog_alphas_real_nn = {}
16
+ cp_pkl_fftlog_alphas_imag_nn = {}
17
+ cp_pkl_fftlog_alphas_nus = {}
18
+ cp_der_nn = {}
19
+ cp_da_nn = {}
20
+ cp_h_nn = {}
21
+ cp_s8_nn = {}
22
+
23
+
24
+ for mp in cosmo_model_list:
25
+ folder, version = split_emulator_string(mp)
26
+ # print(folder, version)
27
+ path_to_emulators = path_to_class_sz_data + '/' + folder +'/'
28
+
29
+ cp_tt_nn[mp] = Restore_NN(restore_filename=path_to_emulators + 'TTTEEE/' + emulator_dict[mp]['TT'])
30
+
31
+ cp_te_nn[mp] = Restore_PCAplusNN(restore_filename=path_to_emulators + 'TTTEEE/' + emulator_dict[mp]['TE'])
32
+
33
+ with suppress_warnings():
34
+ cp_ee_nn[mp] = Restore_NN(restore_filename=path_to_emulators + 'TTTEEE/' + emulator_dict[mp]['EE'])
35
+
36
+ cp_pp_nn[mp] = Restore_NN(restore_filename=path_to_emulators + 'PP/' + emulator_dict[mp]['PP'])
37
+
38
+ cp_pknl_nn[mp] = Restore_NN(restore_filename=path_to_emulators + 'PK/' + emulator_dict[mp]['PKNL'])
39
+
40
+ cp_pkl_nn[mp] = Restore_NN(restore_filename=path_to_emulators + 'PK/' + emulator_dict[mp]['PKL'])
41
+
42
+ if (mp == 'lcdm') and (dofftlog_alphas == True):
43
+ cp_pkl_fftlog_alphas_real_nn[mp] = Restore_PCAplusNN(restore_filename=path_to_emulators + 'PK/' + emulator_dict[mp]['PKLFFTLOG_ALPHAS_REAL']
44
+ )
45
+ cp_pkl_fftlog_alphas_imag_nn[mp] = Restore_PCAplusNN(restore_filename=path_to_emulators + 'PK/' + emulator_dict[mp]['PKLFFTLOG_ALPHAS_IMAG']
46
+ )
47
+ cp_pkl_fftlog_alphas_nus[mp] = np.load(path_to_emulators + 'PK/PKL_FFTLog_alphas_nu_v1.npz')
48
+
49
+ cp_der_nn[mp] = Restore_NN(restore_filename=path_to_emulators + 'derived-parameters/' + emulator_dict[mp]['DER'])
50
+
51
+ cp_da_nn[mp] = Restore_NN(restore_filename=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['DAZ'])
52
+
53
+ cp_h_nn[mp] = Restore_NN(restore_filename=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['HZ'])
54
+
55
+ cp_s8_nn[mp] = Restore_NN(restore_filename=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['S8Z'])
56
+
57
+
58
+
59
+
@@ -0,0 +1,53 @@
1
+ from .config import path_to_class_sz_data
2
+ import numpy as np
3
+ from .restore_nn import Restore_NN
4
+ from .restore_nn import Restore_PCAplusNN
5
+ from .suppress_warnings import suppress_warnings
6
+ from .emulators_meta_data import *
7
+
8
+ from cosmopower_jax.cosmopower_jax import CosmoPowerJAX as CPJ
9
+
10
+
11
+ cp_tt_nn_jax = {}
12
+ cp_te_nn_jax = {}
13
+ cp_ee_nn_jax = {}
14
+ cp_pp_nn_jax = {}
15
+ cp_pknl_nn_jax = {}
16
+ cp_pkl_nn_jax = {}
17
+ cp_der_nn_jax = {}
18
+ cp_da_nn_jax = {}
19
+ cp_h_nn_jax = {}
20
+ cp_s8_nn_jax = {}
21
+
22
+
23
+ for mp in cosmo_model_list:
24
+ folder, version = split_emulator_string(mp)
25
+ # print(folder, version)
26
+ path_to_emulators = path_to_class_sz_data + '/' + folder +'/'
27
+
28
+ cp_tt_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'TTTEEE/' + emulator_dict[mp]['TT'])
29
+
30
+ cp_te_nn_jax[mp] = Restore_PCAplusNN(restore_filename=path_to_emulators + 'TTTEEE/' + emulator_dict[mp]['TE'])
31
+
32
+ with suppress_warnings():
33
+ cp_ee_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'TTTEEE/' + emulator_dict[mp]['EE'])
34
+
35
+ cp_pp_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'PP/' + emulator_dict[mp]['PP'])
36
+
37
+ cp_pknl_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'PK/' + emulator_dict[mp]['PKNL'])
38
+
39
+ cp_pkl_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'PK/' + emulator_dict[mp]['PKL'])
40
+
41
+ cp_der_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'derived-parameters/' + emulator_dict[mp]['DER'])
42
+
43
+ cp_da_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['DAZ'])
44
+
45
+ # print(path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['HZ'])
46
+ emulator_custom = CPJ(probe='custom_log',filepath=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['HZ'] + '.npz')
47
+ # print(emulator_custom.parameters)
48
+ # exit()
49
+
50
+ cp_h_nn_jax[mp] = CPJ(probe='custom_log',filepath=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['HZ'] + '.npz')
51
+
52
+ cp_s8_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['S8Z'])
53
+
@@ -1,9 +1,20 @@
1
- from .utils import *
2
- from .config import *
3
-
4
- from .restore_nn import Restore_NN
5
- from .restore_nn import Restore_PCAplusNN
6
1
  from .suppress_warnings import suppress_warnings
2
+ import warnings
3
+ from contextlib import contextmanager
4
+ import logging
5
+
6
+ # Suppress absl warnings
7
+ import absl.logging
8
+ absl.logging.set_verbosity('error')
9
+ # Suppress TensorFlow warnings
10
+ import os
11
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
12
+ with suppress_warnings():
13
+ import tensorflow as tf
14
+ tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
15
+
16
+
17
+ import re
7
18
 
8
19
  dofftlog_alphas = False
9
20
 
@@ -215,36 +226,7 @@ emulator_dict['ede-v2']['default']['N_ncdm'] = 3
215
226
  emulator_dict['ede-v2']['default']['m_ncdm'] = 0.02
216
227
 
217
228
 
218
- cp_tt_nn = {}
219
- cp_te_nn = {}
220
- cp_ee_nn = {}
221
- cp_pp_nn = {}
222
- cp_pknl_nn = {}
223
- cp_pkl_nn = {}
224
- cp_pkl_fftlog_alphas_real_nn = {}
225
- cp_pkl_fftlog_alphas_imag_nn = {}
226
- cp_pkl_fftlog_alphas_nus = {}
227
- cp_der_nn = {}
228
- cp_da_nn = {}
229
- cp_h_nn = {}
230
- cp_s8_nn = {}
231
229
 
232
- import warnings
233
- from contextlib import contextmanager
234
- import logging
235
-
236
- # Suppress absl warnings
237
- import absl.logging
238
- absl.logging.set_verbosity('error')
239
- # Suppress TensorFlow warnings
240
- import os
241
- os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
242
- with suppress_warnings():
243
- import tensorflow as tf
244
- tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
245
-
246
-
247
- import re
248
230
 
249
231
  def split_emulator_string(input_string):
250
232
  match = re.match(r"(.+)-v(\d+)", input_string)
@@ -258,41 +240,3 @@ def split_emulator_string(input_string):
258
240
  return folder, version
259
241
 
260
242
 
261
-
262
-
263
- for mp in cosmo_model_list:
264
- folder, version = split_emulator_string(mp)
265
- # print(folder, version)
266
- path_to_emulators = path_to_class_sz_data + '/' + folder +'/'
267
-
268
- cp_tt_nn[mp] = Restore_NN(restore_filename=path_to_emulators + 'TTTEEE/' + emulator_dict[mp]['TT'])
269
-
270
- cp_te_nn[mp] = Restore_PCAplusNN(restore_filename=path_to_emulators + 'TTTEEE/' + emulator_dict[mp]['TE'])
271
-
272
- with suppress_warnings():
273
- cp_ee_nn[mp] = Restore_NN(restore_filename=path_to_emulators + 'TTTEEE/' + emulator_dict[mp]['EE'])
274
-
275
- cp_pp_nn[mp] = Restore_NN(restore_filename=path_to_emulators + 'PP/' + emulator_dict[mp]['PP'])
276
-
277
- cp_pknl_nn[mp] = Restore_NN(restore_filename=path_to_emulators + 'PK/' + emulator_dict[mp]['PKNL'])
278
-
279
- cp_pkl_nn[mp] = Restore_NN(restore_filename=path_to_emulators + 'PK/' + emulator_dict[mp]['PKL'])
280
-
281
- if (mp == 'lcdm') and (dofftlog_alphas == True):
282
- cp_pkl_fftlog_alphas_real_nn[mp] = Restore_PCAplusNN(restore_filename=path_to_emulators + 'PK/' + emulator_dict[mp]['PKLFFTLOG_ALPHAS_REAL']
283
- )
284
- cp_pkl_fftlog_alphas_imag_nn[mp] = Restore_PCAplusNN(restore_filename=path_to_emulators + 'PK/' + emulator_dict[mp]['PKLFFTLOG_ALPHAS_IMAG']
285
- )
286
- cp_pkl_fftlog_alphas_nus[mp] = np.load(path_to_emulators + 'PK/PKL_FFTLog_alphas_nu_v1.npz')
287
-
288
- cp_der_nn[mp] = Restore_NN(restore_filename=path_to_emulators + 'derived-parameters/' + emulator_dict[mp]['DER'])
289
-
290
- cp_da_nn[mp] = Restore_NN(restore_filename=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['DAZ'])
291
-
292
- cp_h_nn[mp] = Restore_NN(restore_filename=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['HZ'])
293
-
294
- cp_s8_nn[mp] = Restore_NN(restore_filename=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['S8Z'])
295
-
296
-
297
-
298
-
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: classy_szfast
3
- Version: 0.0.24
3
+ Version: 0.0.25.post1
4
4
  Summary: The accelerator of the class_sz code from https://github.com/CLASS-SZ
5
5
  Maintainer-email: Boris Bolliet <bb667@cam.ac.uk>
6
6
  License: MIT
@@ -13,3 +13,4 @@ Requires-Dist: tensorflow
13
13
  Requires-Dist: mcfit
14
14
  Requires-Dist: get_cosmopower_emus
15
15
  Requires-Dist: class_sz_data
16
+ Requires-Dist: cosmopower-jax
@@ -5,7 +5,9 @@ classy_szfast/classy_sz.py
5
5
  classy_szfast/classy_szfast.py
6
6
  classy_szfast/config.py
7
7
  classy_szfast/cosmopower.py
8
+ classy_szfast/cosmopower_jax.py
8
9
  classy_szfast/cosmosis_classy_szfast_interface.py
10
+ classy_szfast/emulators_meta_data.py
9
11
  classy_szfast/pks_and_sigmas.py
10
12
  classy_szfast/restore_nn.py
11
13
  classy_szfast/suppress_warnings.py
@@ -4,3 +4,4 @@ tensorflow
4
4
  mcfit
5
5
  get_cosmopower_emus
6
6
  class_sz_data
7
+ cosmopower-jax
@@ -3,7 +3,7 @@ requires = ["setuptools", "wheel"]
3
3
  build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
- version = "0.0.24"
6
+ version = "0.0.25.post1"
7
7
  license = { text = "MIT" }
8
8
  name = "classy_szfast"
9
9
  maintainers = [{name = "Boris Bolliet",email="bb667@cam.ac.uk"}]
@@ -15,7 +15,8 @@ dependencies = [
15
15
  "tensorflow",
16
16
  "mcfit",
17
17
  "get_cosmopower_emus",
18
- "class_sz_data"
18
+ "class_sz_data",
19
+ "cosmopower-jax"
19
20
  ]
20
21
 
21
22
  [project.urls]