jaxspec 0.3.0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -122,7 +122,9 @@ class FitResult:
122
122
 
123
123
  samples_shape = (len(posterior.coords["chain"]), len(posterior.coords["draw"]))
124
124
 
125
- total_shape = tuple(posterior.sizes[d] for d in posterior.coords)
125
+ total_shape = tuple(
126
+ posterior.sizes[d] for d in posterior.coords if not (("obs" in d) or ("bkg" in d))
127
+ )
126
128
 
127
129
  posterior = {key: posterior[key].data for key in posterior.data_vars}
128
130
 
@@ -10,9 +10,8 @@ import matplotlib.pyplot as plt
10
10
  import numpyro
11
11
 
12
12
  from flax import nnx
13
- from jax.experimental import mesh_utils
14
13
  from jax.random import PRNGKey
15
- from jax.sharding import PositionalSharding
14
+ from jax.sharding import NamedSharding, PartitionSpec
16
15
  from numpyro.distributions import Poisson, TransformedDistribution
17
16
  from numpyro.infer import Predictive
18
17
  from numpyro.infer.inspect import get_model_relations
@@ -244,7 +243,7 @@ class BayesianModel(nnx.Module):
244
243
  return log_posterior_prob
245
244
 
246
245
  @cached_property
247
- def _parameter_names(self) -> list[str]:
246
+ def parameter_names(self) -> list[str]:
248
247
  """
249
248
  A list of parameter names for the model.
250
249
  """
@@ -269,7 +268,7 @@ class BayesianModel(nnx.Module):
269
268
  """
270
269
  input_params = {}
271
270
 
272
- for index, key in enumerate(self._parameter_names):
271
+ for index, key in enumerate(self.parameter_names):
273
272
  input_params[key] = theta[index]
274
273
 
275
274
  return input_params
@@ -279,9 +278,9 @@ class BayesianModel(nnx.Module):
279
278
  Convert a dictionary of parameters to an array of parameters.
280
279
  """
281
280
 
282
- theta = jnp.zeros(len(self._parameter_names))
281
+ theta = jnp.zeros(len(self.parameter_names))
283
282
 
284
- for index, key in enumerate(self._parameter_names):
283
+ for index, key in enumerate(self.parameter_names):
285
284
  theta = theta.at[index].set(dict_of_params[key])
286
285
 
287
286
  return theta
@@ -298,7 +297,7 @@ class BayesianModel(nnx.Module):
298
297
  @jax.jit
299
298
  def prior_sample(key):
300
299
  return Predictive(
301
- self.numpyro_model, return_sites=self._parameter_names, num_samples=num_samples
300
+ self.numpyro_model, return_sites=self.parameter_names, num_samples=num_samples
302
301
  )(key, observed=False)
303
302
 
304
303
  return prior_sample(key)
@@ -324,7 +323,8 @@ class BayesianModel(nnx.Module):
324
323
  """
325
324
  key_prior, key_posterior = jax.random.split(key, 2)
326
325
  n_devices = len(jax.local_devices())
327
- sharding = PositionalSharding(mesh_utils.create_device_mesh((n_devices,)))
326
+ mesh = jax.make_mesh((n_devices,), ("batch",))
327
+ sharding = NamedSharding(mesh, PartitionSpec("batch"))
328
328
 
329
329
  # Sample from prior and correct if the number of samples is not a multiple of the number of devices
330
330
  if num_samples % n_devices != 0:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: jaxspec
3
- Version: 0.3.0
3
+ Version: 0.3.1
4
4
  Summary: jaxspec is a bayesian spectral fitting library for X-ray astronomy.
5
5
  Project-URL: Homepage, https://github.com/renecotyfanboy/jaxspec
6
6
  Project-URL: Documentation, https://jaxspec.readthedocs.io/en/latest/
@@ -15,18 +15,18 @@ Requires-Dist: chainconsumer<2,>=1.1.2
15
15
  Requires-Dist: cmasher<2,>=1.6.3
16
16
  Requires-Dist: flax>0.10.5
17
17
  Requires-Dist: interpax<0.4,>=0.3.5
18
- Requires-Dist: jax<0.6,>=0.5.0
18
+ Requires-Dist: jax<0.7,>=0.5.0
19
19
  Requires-Dist: jaxns<3,>=2.6.7
20
20
  Requires-Dist: jaxopt<0.9,>=0.8.3
21
21
  Requires-Dist: matplotlib<4,>=3.8.0
22
22
  Requires-Dist: mendeleev<1.2,>=0.15
23
23
  Requires-Dist: networkx~=3.1
24
24
  Requires-Dist: numpy<3.0.0
25
- Requires-Dist: numpyro<0.19,>=0.17.0
25
+ Requires-Dist: numpyro<0.20,>=0.17.0
26
26
  Requires-Dist: optimistix<0.0.11,>=0.0.10
27
27
  Requires-Dist: pandas<3,>=2.2.0
28
28
  Requires-Dist: pooch<2,>=1.8.2
29
- Requires-Dist: scipy<1.15
29
+ Requires-Dist: scipy<1.16
30
30
  Requires-Dist: seaborn<0.14,>=0.13.1
31
31
  Requires-Dist: simpleeval<1.1.0,>=0.9.13
32
32
  Requires-Dist: sparse>0.15
@@ -2,7 +2,7 @@ jaxspec/__init__.py,sha256=Sbn02lX6Y-zNXk17N8dec22c5jeypiS0LkHmGfz7lWA,126
2
2
  jaxspec/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  jaxspec/analysis/_plot.py,sha256=0xEz-e_xk7XvU6PUfbNwxaWg1-SxAF2XAqhkxWEhIFs,6239
4
4
  jaxspec/analysis/compare.py,sha256=g2UFhmR9Zt-7cz5gQFOB6lXuklXB3yTyUvjTypOzoSY,725
5
- jaxspec/analysis/results.py,sha256=_qwDSsThI7FOAR6nMaJltGWlKO5Sz2wc1EQ73Y0Ghho,26013
5
+ jaxspec/analysis/results.py,sha256=tIBWmLoX43EY2BXt50ec8A-DqQ98PMd3m-FqTRT4iRE,26073
6
6
  jaxspec/data/__init__.py,sha256=aantcYKC9kZFvaE-V2SIwSuLhIld17Kjrd9CIUu___Y,415
7
7
  jaxspec/data/instrument.py,sha256=RDiG_LkucvnF2XE_ghTFME6d_2YirgQUcEY0gEle6dk,4775
8
8
  jaxspec/data/obsconf.py,sha256=G0RwNshvbDQzw_ba8Y8NdI-cRsgEj-OlSNdeYCANqVM,10484
@@ -16,7 +16,7 @@ jaxspec/experimental/intrument_models.py,sha256=vuRw7xypPI9YV-Hv8chVNP4ti24dCGjb
16
16
  jaxspec/experimental/nested_sampler.py,sha256=8jCAXQAe2mD5YSNSF0jia_rFWES_MzwRM3FrQQS_x7w,2807
17
17
  jaxspec/experimental/tabulated.py,sha256=H0llUiso2KGH4xUzTUSVPy-6I8D3wm707lU_Z1P5uq4,9429
18
18
  jaxspec/fit/__init__.py,sha256=OaS0-Hkb3Hd-AkE2o-KWfoWMX0NSCPY-_FP2znHf9l0,153
19
- jaxspec/fit/_bayesian_model.py,sha256=BeYukXr86Y1kEmSyiv-6QC4M2rM78Kx_MgGecu4ML98,15179
19
+ jaxspec/fit/_bayesian_model.py,sha256=jSCzAzoAhsmUX7mKUikbUR9A1ZNIaY6rdPOxq6OZSU0,15179
20
20
  jaxspec/fit/_build_model.py,sha256=pNZVuVfwOq3Pg23opH7xRv28DsSkQZpvy2Z-1hQSfNs,3219
21
21
  jaxspec/fit/_fitter.py,sha256=doBTJqTP5CN1OJhZHVlS3oMVOzPJyH4YqOnGevIIU68,8893
22
22
  jaxspec/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -35,8 +35,8 @@ jaxspec/util/integrate.py,sha256=7GwBSagmDzsF3P53tPs-oakeq0zHEwmZZS2zQlXngbE,463
35
35
  jaxspec/util/misc.py,sha256=O3qorCL1Y2X1BS2jdd36C1eDHK9QDXTSOr9kj3uqcJo,654
36
36
  jaxspec/util/online_storage.py,sha256=vm56RfcbFKpkRVfr0bXO7J9aQxuBq-I_oEgA26YIhCo,2469
37
37
  jaxspec/util/typing.py,sha256=ZQM_l68qyYnIBZPz_1mKvwPMx64jvVBD8Uj6bx9sHv0,140
38
- jaxspec-0.3.0.dist-info/METADATA,sha256=92shp3kcwQIbKTSVSD7SU68InowsGVZXST0uJYvRwnQ,4199
39
- jaxspec-0.3.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
40
- jaxspec-0.3.0.dist-info/entry_points.txt,sha256=4ffU5AImfcEBxgWTqopQll2YffpFldOswXRh16pd0Dc,72
41
- jaxspec-0.3.0.dist-info/licenses/LICENSE.md,sha256=2q5XoWzddts5IqzIcgYYMOL21puU3MfO8gvT3Ype1eQ,1073
42
- jaxspec-0.3.0.dist-info/RECORD,,
38
+ jaxspec-0.3.1.dist-info/METADATA,sha256=8i1cuzZY4iwIjWEhIPXBiC-Z8Y4Vv27omIMwTKnoPwo,4199
39
+ jaxspec-0.3.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
40
+ jaxspec-0.3.1.dist-info/entry_points.txt,sha256=4ffU5AImfcEBxgWTqopQll2YffpFldOswXRh16pd0Dc,72
41
+ jaxspec-0.3.1.dist-info/licenses/LICENSE.md,sha256=2q5XoWzddts5IqzIcgYYMOL21puU3MfO8gvT3Ype1eQ,1073
42
+ jaxspec-0.3.1.dist-info/RECORD,,