classy-szfast 0.0.25.post12__tar.gz → 0.0.25.post13__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (25):
  1. {classy_szfast-0.0.25.post12 → classy_szfast-0.0.25.post13}/PKG-INFO +1 -1
  2. {classy_szfast-0.0.25.post12 → classy_szfast-0.0.25.post13}/classy_szfast/cosmopower_jax.py +27 -0
  3. {classy_szfast-0.0.25.post12 → classy_szfast-0.0.25.post13}/classy_szfast.egg-info/PKG-INFO +1 -1
  4. {classy_szfast-0.0.25.post12 → classy_szfast-0.0.25.post13}/pyproject.toml +1 -1
  5. {classy_szfast-0.0.25.post12 → classy_szfast-0.0.25.post13}/README.md +0 -0
  6. {classy_szfast-0.0.25.post12 → classy_szfast-0.0.25.post13}/classy_szfast/__init__.py +0 -0
  7. {classy_szfast-0.0.25.post12 → classy_szfast-0.0.25.post13}/classy_szfast/classy_sz.py +0 -0
  8. {classy_szfast-0.0.25.post12 → classy_szfast-0.0.25.post13}/classy_szfast/classy_szfast.py +0 -0
  9. {classy_szfast-0.0.25.post12 → classy_szfast-0.0.25.post13}/classy_szfast/config.py +0 -0
  10. {classy_szfast-0.0.25.post12 → classy_szfast-0.0.25.post13}/classy_szfast/cosmopower.py +0 -0
  11. {classy_szfast-0.0.25.post12 → classy_szfast-0.0.25.post13}/classy_szfast/cosmosis_classy_szfast_interface.py +0 -0
  12. {classy_szfast-0.0.25.post12 → classy_szfast-0.0.25.post13}/classy_szfast/custom_bias/__init__.py +0 -0
  13. {classy_szfast-0.0.25.post12 → classy_szfast-0.0.25.post13}/classy_szfast/custom_bias/custom_bias.py +0 -0
  14. {classy_szfast-0.0.25.post12 → classy_szfast-0.0.25.post13}/classy_szfast/custom_profiles/__init__.py +0 -0
  15. {classy_szfast-0.0.25.post12 → classy_szfast-0.0.25.post13}/classy_szfast/custom_profiles/custom_profiles.py +0 -0
  16. {classy_szfast-0.0.25.post12 → classy_szfast-0.0.25.post13}/classy_szfast/emulators_meta_data.py +0 -0
  17. {classy_szfast-0.0.25.post12 → classy_szfast-0.0.25.post13}/classy_szfast/pks_and_sigmas.py +0 -0
  18. {classy_szfast-0.0.25.post12 → classy_szfast-0.0.25.post13}/classy_szfast/restore_nn.py +0 -0
  19. {classy_szfast-0.0.25.post12 → classy_szfast-0.0.25.post13}/classy_szfast/suppress_warnings.py +0 -0
  20. {classy_szfast-0.0.25.post12 → classy_szfast-0.0.25.post13}/classy_szfast/utils.py +0 -0
  21. {classy_szfast-0.0.25.post12 → classy_szfast-0.0.25.post13}/classy_szfast.egg-info/SOURCES.txt +0 -0
  22. {classy_szfast-0.0.25.post12 → classy_szfast-0.0.25.post13}/classy_szfast.egg-info/dependency_links.txt +0 -0
  23. {classy_szfast-0.0.25.post12 → classy_szfast-0.0.25.post13}/classy_szfast.egg-info/requires.txt +0 -0
  24. {classy_szfast-0.0.25.post12 → classy_szfast-0.0.25.post13}/classy_szfast.egg-info/top_level.txt +0 -0
  25. {classy_szfast-0.0.25.post12 → classy_szfast-0.0.25.post13}/setup.cfg +0 -0
--- classy_szfast-0.0.25.post12/PKG-INFO
+++ classy_szfast-0.0.25.post13/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: classy_szfast
-Version: 0.0.25.post12
+Version: 0.0.25.post13
 Summary: The accelerator of the class_sz code from https://github.com/CLASS-SZ
 Maintainer-email: Boris Bolliet <bb667@cam.ac.uk>
 License: MIT
--- classy_szfast-0.0.25.post12/classy_szfast/cosmopower_jax.py
+++ classy_szfast-0.0.25.post13/classy_szfast/cosmopower_jax.py
@@ -7,6 +7,7 @@ from .suppress_warnings import suppress_warnings
 from .emulators_meta_data import *
 
 from cosmopower_jax.cosmopower_jax import CosmoPowerJAX as CPJ
+from jax.errors import TracerArrayConversionError
 
 
 cp_tt_nn_jax = {}
@@ -28,6 +29,32 @@ class CosmoPowerJAX_custom(CPJ):
         if 'ten_to_predictions' in kwargs.keys():
             self.ten_to_predictions = kwargs['ten_to_predictions']
 
+    def _dict_to_ordered_arr_np(self,
+                               input_dict,
+                               ):
+        """
+        Sort input parameters. Takend verbatim from CP
+        (https://github.com/alessiospuriomancini/cosmopower/blob/main/cosmopower/cosmopower_NN.py#LL291C1-L308C73)
+
+        Parameters:
+            input_dict (dict [numpy.ndarray]):
+                input dict of (arrays of) parameters to be sorted
+
+        Returns:
+            numpy.ndarray:
+                parameters sorted according to desired order
+        """
+        if self.parameters is not None:
+            try:
+                return np.stack([input_dict[k] for k in self.parameters], axis=1)
+            except TracerArrayConversionError:
+                converted_dict = {k: jnp.array(v) if isinstance(v, list) else v for k, v in input_dict.items()}
+                return jnp.stack([converted_dict[k] for k in self.parameters], axis=1)
+
+        else:
+            return np.stack([input_dict[k] for k in input_dict], axis=1)
+
+
     def _predict(self, weights, hyper_params, param_train_mean, param_train_std,
                  feature_train_mean, feature_train_std, input_vec):
         """ Forward pass through pre-trained network.
--- classy_szfast-0.0.25.post12/classy_szfast.egg-info/PKG-INFO
+++ classy_szfast-0.0.25.post13/classy_szfast.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: classy_szfast
-Version: 0.0.25.post12
+Version: 0.0.25.post13
 Summary: The accelerator of the class_sz code from https://github.com/CLASS-SZ
 Maintainer-email: Boris Bolliet <bb667@cam.ac.uk>
 License: MIT
--- classy_szfast-0.0.25.post12/pyproject.toml
+++ classy_szfast-0.0.25.post13/pyproject.toml
@@ -3,7 +3,7 @@ requires = ["setuptools", "wheel"]
 build-backend = "setuptools.build_meta"
 
 [project]
-version = "0.0.25.post12"
+version = "0.0.25.post13"
 license = { text = "MIT" }
 name = "classy_szfast"
 maintainers = [{name = "Boris Bolliet",email="bb667@cam.ac.uk"}]