classy-szfast 0.0.25.post12__tar.gz → 0.0.25.post14__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (25)
  1. {classy_szfast-0.0.25.post12 → classy_szfast-0.0.25.post14}/PKG-INFO +2 -2
  2. {classy_szfast-0.0.25.post12 → classy_szfast-0.0.25.post14}/classy_szfast/classy_szfast.py +30 -1
  3. {classy_szfast-0.0.25.post12 → classy_szfast-0.0.25.post14}/classy_szfast/cosmopower_jax.py +27 -0
  4. {classy_szfast-0.0.25.post12 → classy_szfast-0.0.25.post14}/classy_szfast.egg-info/PKG-INFO +2 -2
  5. {classy_szfast-0.0.25.post12 → classy_szfast-0.0.25.post14}/pyproject.toml +1 -1
  6. {classy_szfast-0.0.25.post12 → classy_szfast-0.0.25.post14}/README.md +0 -0
  7. {classy_szfast-0.0.25.post12 → classy_szfast-0.0.25.post14}/classy_szfast/__init__.py +0 -0
  8. {classy_szfast-0.0.25.post12 → classy_szfast-0.0.25.post14}/classy_szfast/classy_sz.py +0 -0
  9. {classy_szfast-0.0.25.post12 → classy_szfast-0.0.25.post14}/classy_szfast/config.py +0 -0
  10. {classy_szfast-0.0.25.post12 → classy_szfast-0.0.25.post14}/classy_szfast/cosmopower.py +0 -0
  11. {classy_szfast-0.0.25.post12 → classy_szfast-0.0.25.post14}/classy_szfast/cosmosis_classy_szfast_interface.py +0 -0
  12. {classy_szfast-0.0.25.post12 → classy_szfast-0.0.25.post14}/classy_szfast/custom_bias/__init__.py +0 -0
  13. {classy_szfast-0.0.25.post12 → classy_szfast-0.0.25.post14}/classy_szfast/custom_bias/custom_bias.py +0 -0
  14. {classy_szfast-0.0.25.post12 → classy_szfast-0.0.25.post14}/classy_szfast/custom_profiles/__init__.py +0 -0
  15. {classy_szfast-0.0.25.post12 → classy_szfast-0.0.25.post14}/classy_szfast/custom_profiles/custom_profiles.py +0 -0
  16. {classy_szfast-0.0.25.post12 → classy_szfast-0.0.25.post14}/classy_szfast/emulators_meta_data.py +0 -0
  17. {classy_szfast-0.0.25.post12 → classy_szfast-0.0.25.post14}/classy_szfast/pks_and_sigmas.py +0 -0
  18. {classy_szfast-0.0.25.post12 → classy_szfast-0.0.25.post14}/classy_szfast/restore_nn.py +0 -0
  19. {classy_szfast-0.0.25.post12 → classy_szfast-0.0.25.post14}/classy_szfast/suppress_warnings.py +0 -0
  20. {classy_szfast-0.0.25.post12 → classy_szfast-0.0.25.post14}/classy_szfast/utils.py +0 -0
  21. {classy_szfast-0.0.25.post12 → classy_szfast-0.0.25.post14}/classy_szfast.egg-info/SOURCES.txt +0 -0
  22. {classy_szfast-0.0.25.post12 → classy_szfast-0.0.25.post14}/classy_szfast.egg-info/dependency_links.txt +0 -0
  23. {classy_szfast-0.0.25.post12 → classy_szfast-0.0.25.post14}/classy_szfast.egg-info/requires.txt +0 -0
  24. {classy_szfast-0.0.25.post12 → classy_szfast-0.0.25.post14}/classy_szfast.egg-info/top_level.txt +0 -0
  25. {classy_szfast-0.0.25.post12 → classy_szfast-0.0.25.post14}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.2
2
2
  Name: classy_szfast
3
- Version: 0.0.25.post12
3
+ Version: 0.0.25.post14
4
4
  Summary: The accelerator of the class_sz code from https://github.com/CLASS-SZ
5
5
  Maintainer-email: Boris Bolliet <bb667@cam.ac.uk>
6
6
  License: MIT
@@ -173,9 +173,16 @@ class Class_szfast(object):
173
173
  self.cp_pkl_fftlog_alphas_imag_nn = cp_pkl_fftlog_alphas_imag_nn
174
174
 
175
175
  self.cosmo_model = 'ede-v2'
176
+
176
177
  self.use_Amod = 0
177
178
  self.Amod = 0
178
-
179
+
180
+ self.use_pk_z_bins = 0
181
+ self.pk_z_bins_z1 = 0
182
+ self.pk_z_bins_z2 = 0
183
+ self.pk_z_bins_A0 = 0
184
+ self.pk_z_bins_A1 = 0
185
+ self.pk_z_bins_A2 = 0
179
186
 
180
187
 
181
188
  if cosmo_model_dict[params_settings['cosmo_model']] == 'ede-v2':
@@ -223,6 +230,15 @@ class Class_szfast(object):
223
230
  self.use_Amod = v
224
231
  self.Amod = params_settings['Amod']
225
232
 
233
+ if k == 'use_pk_z_bins':
234
+ self.use_pk_z_bins = v
235
+ self.pk_z_bins_z1 = params_settings['pk_z_bins_z1']
236
+ self.pk_z_bins_z2 = params_settings['pk_z_bins_z2']
237
+ self.pk_z_bins_A0 = params_settings['pk_z_bins_A0']
238
+ self.pk_z_bins_A1 = params_settings['pk_z_bins_A1']
239
+ self.pk_z_bins_A2 = params_settings['pk_z_bins_A2']
240
+
241
+
226
242
 
227
243
 
228
244
  if cosmo_model_dict[params_settings['cosmo_model']] == 'ede-v2':
@@ -535,6 +551,19 @@ class Class_szfast(object):
535
551
  pk_ae = pkl_p + self.Amod*(pknl_p-pkl_p)
536
552
  predicted_pk_spectrum_z.append(pk_ae)
537
553
 
554
+ elif self.use_pk_z_bins:
555
+ # print('>>> using pk_z_bins')
556
+ for zp in z_arr:
557
+ params_dict_pp = params_dict.copy()
558
+ params_dict_pp['z_pk_save_nonclass'] = [zp]
559
+ pkl_p = self.cp_pkl_nn[self.cosmo_model].predictions_np(params_dict_pp)[0]
560
+ if zp < self.pk_z_bins_z1:
561
+ pklp = self.pk_z_bins_A0*pkl_p
562
+ elif zp < self.pk_z_bins_z2:
563
+ pklp = self.pk_z_bins_A1*pkl_p
564
+ else:
565
+ pklp = self.pk_z_bins_A2*pkl_p
566
+ predicted_pk_spectrum_z.append(pklp)
538
567
  else:
539
568
 
540
569
  for zp in z_arr:
@@ -7,6 +7,7 @@ from .suppress_warnings import suppress_warnings
7
7
  from .emulators_meta_data import *
8
8
 
9
9
  from cosmopower_jax.cosmopower_jax import CosmoPowerJAX as CPJ
10
+ from jax.errors import TracerArrayConversionError
10
11
 
11
12
 
12
13
  cp_tt_nn_jax = {}
@@ -28,6 +29,32 @@ class CosmoPowerJAX_custom(CPJ):
28
29
  if 'ten_to_predictions' in kwargs.keys():
29
30
  self.ten_to_predictions = kwargs['ten_to_predictions']
30
31
 
32
+ def _dict_to_ordered_arr_np(self,
33
+ input_dict,
34
+ ):
35
+ """
36
+ Sort input parameters. Takend verbatim from CP
37
+ (https://github.com/alessiospuriomancini/cosmopower/blob/main/cosmopower/cosmopower_NN.py#LL291C1-L308C73)
38
+
39
+ Parameters:
40
+ input_dict (dict [numpy.ndarray]):
41
+ input dict of (arrays of) parameters to be sorted
42
+
43
+ Returns:
44
+ numpy.ndarray:
45
+ parameters sorted according to desired order
46
+ """
47
+ if self.parameters is not None:
48
+ try:
49
+ return np.stack([input_dict[k] for k in self.parameters], axis=1)
50
+ except TracerArrayConversionError:
51
+ converted_dict = {k: jnp.array(v) if isinstance(v, list) else v for k, v in input_dict.items()}
52
+ return jnp.stack([converted_dict[k] for k in self.parameters], axis=1)
53
+
54
+ else:
55
+ return np.stack([input_dict[k] for k in input_dict], axis=1)
56
+
57
+
31
58
  def _predict(self, weights, hyper_params, param_train_mean, param_train_std,
32
59
  feature_train_mean, feature_train_std, input_vec):
33
60
  """ Forward pass through pre-trained network.
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.2
2
2
  Name: classy_szfast
3
- Version: 0.0.25.post12
3
+ Version: 0.0.25.post14
4
4
  Summary: The accelerator of the class_sz code from https://github.com/CLASS-SZ
5
5
  Maintainer-email: Boris Bolliet <bb667@cam.ac.uk>
6
6
  License: MIT
@@ -3,7 +3,7 @@ requires = ["setuptools", "wheel"]
3
3
  build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
- version = "0.0.25.post12"
6
+ version = "0.0.25.post14"
7
7
  license = { text = "MIT" }
8
8
  name = "classy_szfast"
9
9
  maintainers = [{name = "Boris Bolliet",email="bb667@cam.ac.uk"}]