jaxspec 0.3.0__tar.gz → 0.3.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (120)
  1. {jaxspec-0.3.0 → jaxspec-0.3.1}/.gitignore +2 -0
  2. {jaxspec-0.3.0 → jaxspec-0.3.1}/PKG-INFO +4 -4
  3. {jaxspec-0.3.0 → jaxspec-0.3.1}/pyproject.toml +5 -5
  4. {jaxspec-0.3.0 → jaxspec-0.3.1}/src/jaxspec/analysis/results.py +3 -1
  5. {jaxspec-0.3.0 → jaxspec-0.3.1}/src/jaxspec/fit/_bayesian_model.py +8 -8
  6. {jaxspec-0.3.0 → jaxspec-0.3.1}/tests/test_fakeit.py +3 -3
  7. {jaxspec-0.3.0 → jaxspec-0.3.1}/.dockerignore +0 -0
  8. {jaxspec-0.3.0 → jaxspec-0.3.1}/.github/dependabot.yml +0 -0
  9. {jaxspec-0.3.0 → jaxspec-0.3.1}/.github/workflows/documentation-links.yml +0 -0
  10. {jaxspec-0.3.0 → jaxspec-0.3.1}/.github/workflows/publish.yml +0 -0
  11. {jaxspec-0.3.0 → jaxspec-0.3.1}/.github/workflows/test-and-coverage.yml +0 -0
  12. {jaxspec-0.3.0 → jaxspec-0.3.1}/.pre-commit-config.yaml +0 -0
  13. {jaxspec-0.3.0 → jaxspec-0.3.1}/.python-version +0 -0
  14. {jaxspec-0.3.0 → jaxspec-0.3.1}/.readthedocs.yaml +0 -0
  15. {jaxspec-0.3.0 → jaxspec-0.3.1}/CODE_OF_CONDUCT.md +0 -0
  16. {jaxspec-0.3.0 → jaxspec-0.3.1}/Dockerfile +0 -0
  17. {jaxspec-0.3.0 → jaxspec-0.3.1}/LICENSE.md +0 -0
  18. {jaxspec-0.3.0 → jaxspec-0.3.1}/README.md +0 -0
  19. {jaxspec-0.3.0 → jaxspec-0.3.1}/codecov.yml +0 -0
  20. {jaxspec-0.3.0 → jaxspec-0.3.1}/docs/contribute/index.md +0 -0
  21. {jaxspec-0.3.0 → jaxspec-0.3.1}/docs/contribute/internal.md +0 -0
  22. {jaxspec-0.3.0 → jaxspec-0.3.1}/docs/contribute/xspec.md +0 -0
  23. {jaxspec-0.3.0 → jaxspec-0.3.1}/docs/css/extra.css +0 -0
  24. {jaxspec-0.3.0 → jaxspec-0.3.1}/docs/css/material.css +0 -0
  25. {jaxspec-0.3.0 → jaxspec-0.3.1}/docs/css/mkdocstrings.css +0 -0
  26. {jaxspec-0.3.0 → jaxspec-0.3.1}/docs/css/xarray.css +0 -0
  27. {jaxspec-0.3.0 → jaxspec-0.3.1}/docs/dev/index.md +0 -0
  28. {jaxspec-0.3.0 → jaxspec-0.3.1}/docs/examples/background.md +0 -0
  29. {jaxspec-0.3.0 → jaxspec-0.3.1}/docs/examples/build_model.md +0 -0
  30. {jaxspec-0.3.0 → jaxspec-0.3.1}/docs/examples/fakeits.md +0 -0
  31. {jaxspec-0.3.0 → jaxspec-0.3.1}/docs/examples/fitting_example.md +0 -0
  32. {jaxspec-0.3.0 → jaxspec-0.3.1}/docs/examples/index.md +0 -0
  33. {jaxspec-0.3.0 → jaxspec-0.3.1}/docs/examples/statics/background_comparison.png +0 -0
  34. {jaxspec-0.3.0 → jaxspec-0.3.1}/docs/examples/statics/background_gp.png +0 -0
  35. {jaxspec-0.3.0 → jaxspec-0.3.1}/docs/examples/statics/background_spectral.png +0 -0
  36. {jaxspec-0.3.0 → jaxspec-0.3.1}/docs/examples/statics/fakeits.png +0 -0
  37. {jaxspec-0.3.0 → jaxspec-0.3.1}/docs/examples/statics/fitting_example_corner.png +0 -0
  38. {jaxspec-0.3.0 → jaxspec-0.3.1}/docs/examples/statics/fitting_example_ppc.png +0 -0
  39. {jaxspec-0.3.0 → jaxspec-0.3.1}/docs/examples/statics/model.png +0 -0
  40. {jaxspec-0.3.0 → jaxspec-0.3.1}/docs/examples/statics/rmf.png +0 -0
  41. {jaxspec-0.3.0 → jaxspec-0.3.1}/docs/examples/statics/subtract_background.png +0 -0
  42. {jaxspec-0.3.0 → jaxspec-0.3.1}/docs/examples/statics/subtract_background_with_errors.png +0 -0
  43. {jaxspec-0.3.0 → jaxspec-0.3.1}/docs/faq/cookbook.md +0 -0
  44. {jaxspec-0.3.0 → jaxspec-0.3.1}/docs/faq/index.md +0 -0
  45. {jaxspec-0.3.0 → jaxspec-0.3.1}/docs/faq/statics/cstat_vs_chi2.png +0 -0
  46. {jaxspec-0.3.0 → jaxspec-0.3.1}/docs/frontpage/installation.md +0 -0
  47. {jaxspec-0.3.0 → jaxspec-0.3.1}/docs/index.md +0 -0
  48. {jaxspec-0.3.0 → jaxspec-0.3.1}/docs/javascripts/mathjax.js +0 -0
  49. {jaxspec-0.3.0 → jaxspec-0.3.1}/docs/logo/logo_small.svg +0 -0
  50. {jaxspec-0.3.0 → jaxspec-0.3.1}/docs/logo/xifu_mini.svg +0 -0
  51. {jaxspec-0.3.0 → jaxspec-0.3.1}/docs/references/abundance.md +0 -0
  52. {jaxspec-0.3.0 → jaxspec-0.3.1}/docs/references/additive.md +0 -0
  53. {jaxspec-0.3.0 → jaxspec-0.3.1}/docs/references/background.md +0 -0
  54. {jaxspec-0.3.0 → jaxspec-0.3.1}/docs/references/data.md +0 -0
  55. {jaxspec-0.3.0 → jaxspec-0.3.1}/docs/references/fitting.md +0 -0
  56. {jaxspec-0.3.0 → jaxspec-0.3.1}/docs/references/integrate.md +0 -0
  57. {jaxspec-0.3.0 → jaxspec-0.3.1}/docs/references/model.md +0 -0
  58. {jaxspec-0.3.0 → jaxspec-0.3.1}/docs/references/multiplicative.md +0 -0
  59. {jaxspec-0.3.0 → jaxspec-0.3.1}/docs/references/results.md +0 -0
  60. {jaxspec-0.3.0 → jaxspec-0.3.1}/docs/runtime/diagram.txt +0 -0
  61. {jaxspec-0.3.0 → jaxspec-0.3.1}/docs/runtime/result_table.txt +0 -0
  62. {jaxspec-0.3.0 → jaxspec-0.3.1}/docs/runtime/various_model_graphs/diagram.txt +0 -0
  63. {jaxspec-0.3.0 → jaxspec-0.3.1}/docs/runtime/various_model_graphs/model_complex.txt +0 -0
  64. {jaxspec-0.3.0 → jaxspec-0.3.1}/docs/runtime/various_model_graphs/model_simple.txt +0 -0
  65. {jaxspec-0.3.0 → jaxspec-0.3.1}/docs/theory/background.md +0 -0
  66. {jaxspec-0.3.0 → jaxspec-0.3.1}/docs/theory/bayesian_inference.md +0 -0
  67. {jaxspec-0.3.0 → jaxspec-0.3.1}/docs/theory/index.md +0 -0
  68. {jaxspec-0.3.0 → jaxspec-0.3.1}/docs/theory/instrument.md +0 -0
  69. {jaxspec-0.3.0 → jaxspec-0.3.1}/mkdocs.yml +0 -0
  70. {jaxspec-0.3.0 → jaxspec-0.3.1}/src/jaxspec/__init__.py +0 -0
  71. {jaxspec-0.3.0 → jaxspec-0.3.1}/src/jaxspec/analysis/__init__.py +0 -0
  72. {jaxspec-0.3.0 → jaxspec-0.3.1}/src/jaxspec/analysis/_plot.py +0 -0
  73. {jaxspec-0.3.0 → jaxspec-0.3.1}/src/jaxspec/analysis/compare.py +0 -0
  74. {jaxspec-0.3.0 → jaxspec-0.3.1}/src/jaxspec/data/__init__.py +0 -0
  75. {jaxspec-0.3.0 → jaxspec-0.3.1}/src/jaxspec/data/instrument.py +0 -0
  76. {jaxspec-0.3.0 → jaxspec-0.3.1}/src/jaxspec/data/obsconf.py +0 -0
  77. {jaxspec-0.3.0 → jaxspec-0.3.1}/src/jaxspec/data/observation.py +0 -0
  78. {jaxspec-0.3.0 → jaxspec-0.3.1}/src/jaxspec/data/ogip.py +0 -0
  79. {jaxspec-0.3.0 → jaxspec-0.3.1}/src/jaxspec/data/util.py +0 -0
  80. {jaxspec-0.3.0 → jaxspec-0.3.1}/src/jaxspec/experimental/__init__.py +0 -0
  81. {jaxspec-0.3.0 → jaxspec-0.3.1}/src/jaxspec/experimental/interpolator.py +0 -0
  82. {jaxspec-0.3.0 → jaxspec-0.3.1}/src/jaxspec/experimental/interpolator_jax.py +0 -0
  83. {jaxspec-0.3.0 → jaxspec-0.3.1}/src/jaxspec/experimental/intrument_models.py +0 -0
  84. {jaxspec-0.3.0 → jaxspec-0.3.1}/src/jaxspec/experimental/nested_sampler.py +0 -0
  85. {jaxspec-0.3.0 → jaxspec-0.3.1}/src/jaxspec/experimental/tabulated.py +0 -0
  86. {jaxspec-0.3.0 → jaxspec-0.3.1}/src/jaxspec/fit/__init__.py +0 -0
  87. {jaxspec-0.3.0 → jaxspec-0.3.1}/src/jaxspec/fit/_build_model.py +0 -0
  88. {jaxspec-0.3.0 → jaxspec-0.3.1}/src/jaxspec/fit/_fitter.py +0 -0
  89. {jaxspec-0.3.0 → jaxspec-0.3.1}/src/jaxspec/model/__init__.py +0 -0
  90. {jaxspec-0.3.0 → jaxspec-0.3.1}/src/jaxspec/model/_graph_util.py +0 -0
  91. {jaxspec-0.3.0 → jaxspec-0.3.1}/src/jaxspec/model/abc.py +0 -0
  92. {jaxspec-0.3.0 → jaxspec-0.3.1}/src/jaxspec/model/additive.py +0 -0
  93. {jaxspec-0.3.0 → jaxspec-0.3.1}/src/jaxspec/model/background.py +0 -0
  94. {jaxspec-0.3.0 → jaxspec-0.3.1}/src/jaxspec/model/instrument.py +0 -0
  95. {jaxspec-0.3.0 → jaxspec-0.3.1}/src/jaxspec/model/list.py +0 -0
  96. {jaxspec-0.3.0 → jaxspec-0.3.1}/src/jaxspec/model/multiplicative.py +0 -0
  97. {jaxspec-0.3.0 → jaxspec-0.3.1}/src/jaxspec/scripts/__init__.py +0 -0
  98. {jaxspec-0.3.0 → jaxspec-0.3.1}/src/jaxspec/scripts/debug.py +0 -0
  99. {jaxspec-0.3.0 → jaxspec-0.3.1}/src/jaxspec/util/__init__.py +0 -0
  100. {jaxspec-0.3.0 → jaxspec-0.3.1}/src/jaxspec/util/abundance.py +0 -0
  101. {jaxspec-0.3.0 → jaxspec-0.3.1}/src/jaxspec/util/integrate.py +0 -0
  102. {jaxspec-0.3.0 → jaxspec-0.3.1}/src/jaxspec/util/misc.py +0 -0
  103. {jaxspec-0.3.0 → jaxspec-0.3.1}/src/jaxspec/util/online_storage.py +0 -0
  104. {jaxspec-0.3.0 → jaxspec-0.3.1}/src/jaxspec/util/typing.py +0 -0
  105. {jaxspec-0.3.0 → jaxspec-0.3.1}/tests/conftest.py +0 -0
  106. {jaxspec-0.3.0 → jaxspec-0.3.1}/tests/data_files.yml +0 -0
  107. {jaxspec-0.3.0 → jaxspec-0.3.1}/tests/data_hash.yml +0 -0
  108. {jaxspec-0.3.0 → jaxspec-0.3.1}/tests/test_background.py +0 -0
  109. {jaxspec-0.3.0 → jaxspec-0.3.1}/tests/test_bayesian_model.py +0 -0
  110. {jaxspec-0.3.0 → jaxspec-0.3.1}/tests/test_bayesian_model_building.py +0 -0
  111. {jaxspec-0.3.0 → jaxspec-0.3.1}/tests/test_instruments.py +0 -0
  112. {jaxspec-0.3.0 → jaxspec-0.3.1}/tests/test_integrate.py +0 -0
  113. {jaxspec-0.3.0 → jaxspec-0.3.1}/tests/test_mcmc.py +0 -0
  114. {jaxspec-0.3.0 → jaxspec-0.3.1}/tests/test_misc.py +0 -0
  115. {jaxspec-0.3.0 → jaxspec-0.3.1}/tests/test_models.py +0 -0
  116. {jaxspec-0.3.0 → jaxspec-0.3.1}/tests/test_observation.py +0 -0
  117. {jaxspec-0.3.0 → jaxspec-0.3.1}/tests/test_repr.py +0 -0
  118. {jaxspec-0.3.0 → jaxspec-0.3.1}/tests/test_results.py +0 -0
  119. {jaxspec-0.3.0 → jaxspec-0.3.1}/tests/test_xspec.py +0 -0
  120. {jaxspec-0.3.0 → jaxspec-0.3.1}/tests/test_xspec_models.py +0 -0
@@ -197,4 +197,6 @@ cython_debug/
197
197
  .Trashes
198
198
  ehthumbs.db
199
199
  Thumbs.db
200
+
201
+ _old/
200
202
  # End of https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: jaxspec
3
- Version: 0.3.0
3
+ Version: 0.3.1
4
4
  Summary: jaxspec is a bayesian spectral fitting library for X-ray astronomy.
5
5
  Project-URL: Homepage, https://github.com/renecotyfanboy/jaxspec
6
6
  Project-URL: Documentation, https://jaxspec.readthedocs.io/en/latest/
@@ -15,18 +15,18 @@ Requires-Dist: chainconsumer<2,>=1.1.2
15
15
  Requires-Dist: cmasher<2,>=1.6.3
16
16
  Requires-Dist: flax>0.10.5
17
17
  Requires-Dist: interpax<0.4,>=0.3.5
18
- Requires-Dist: jax<0.6,>=0.5.0
18
+ Requires-Dist: jax<0.7,>=0.5.0
19
19
  Requires-Dist: jaxns<3,>=2.6.7
20
20
  Requires-Dist: jaxopt<0.9,>=0.8.3
21
21
  Requires-Dist: matplotlib<4,>=3.8.0
22
22
  Requires-Dist: mendeleev<1.2,>=0.15
23
23
  Requires-Dist: networkx~=3.1
24
24
  Requires-Dist: numpy<3.0.0
25
- Requires-Dist: numpyro<0.19,>=0.17.0
25
+ Requires-Dist: numpyro<0.20,>=0.17.0
26
26
  Requires-Dist: optimistix<0.0.11,>=0.0.10
27
27
  Requires-Dist: pandas<3,>=2.2.0
28
28
  Requires-Dist: pooch<2,>=1.8.2
29
- Requires-Dist: scipy<1.15
29
+ Requires-Dist: scipy<1.16
30
30
  Requires-Dist: seaborn<0.14,>=0.13.1
31
31
  Requires-Dist: simpleeval<1.1.0,>=0.9.13
32
32
  Requires-Dist: sparse>0.15
@@ -1,17 +1,17 @@
1
1
  [project]
2
2
  name = "jaxspec"
3
- version = "0.3.0"
3
+ version = "0.3.1"
4
4
  description = "jaxspec is a bayesian spectral fitting library for X-ray astronomy."
5
5
  authors = [{ name = "sdupourque", email = "sdupourque@irap.omp.eu" }]
6
6
  requires-python = ">=3.10,<3.13"
7
7
  readme = "README.md"
8
8
  license = "MIT"
9
9
  dependencies = [
10
- "jax>=0.5.0,<0.6",
10
+ "jax>=0.5.0,<0.7",
11
11
  "numpy<3.0.0",
12
12
  "pandas>=2.2.0,<3",
13
13
  "astropy>=6.0.0,<8",
14
- "numpyro>=0.17.0,<0.19",
14
+ "numpyro>=0.17.0,<0.20",
15
15
  "networkx~=3.1",
16
16
  "matplotlib>=3.8.0,<4",
17
17
  "arviz>=0.17.1,<0.23.0",
@@ -23,7 +23,7 @@ dependencies = [
23
23
  "seaborn>=0.13.1,<0.14",
24
24
  "sparse>0.15",
25
25
  "optimistix>=0.0.10,<0.0.11",
26
- "scipy<1.15",
26
+ "scipy<1.16",
27
27
  "mendeleev>=0.15,<1.2",
28
28
  "jaxns>=2.6.7,<3",
29
29
  "pooch>=1.8.2,<2",
@@ -58,7 +58,7 @@ test = [
58
58
  ]
59
59
  dev = [
60
60
  "pre-commit>=3.5,<5.0",
61
- "ruff>=0.2.1,<0.10.0",
61
+ "ruff>=0.2.1,<0.15.0",
62
62
  "jupyterlab>=4.0.7,<5",
63
63
  "notebook>=7.0.6,<8",
64
64
  "ipywidgets>=8.1.1,<9",
@@ -122,7 +122,9 @@ class FitResult:
122
122
 
123
123
  samples_shape = (len(posterior.coords["chain"]), len(posterior.coords["draw"]))
124
124
 
125
- total_shape = tuple(posterior.sizes[d] for d in posterior.coords)
125
+ total_shape = tuple(
126
+ posterior.sizes[d] for d in posterior.coords if not (("obs" in d) or ("bkg" in d))
127
+ )
126
128
 
127
129
  posterior = {key: posterior[key].data for key in posterior.data_vars}
128
130
 
@@ -10,9 +10,8 @@ import matplotlib.pyplot as plt
10
10
  import numpyro
11
11
 
12
12
  from flax import nnx
13
- from jax.experimental import mesh_utils
14
13
  from jax.random import PRNGKey
15
- from jax.sharding import PositionalSharding
14
+ from jax.sharding import NamedSharding, PartitionSpec
16
15
  from numpyro.distributions import Poisson, TransformedDistribution
17
16
  from numpyro.infer import Predictive
18
17
  from numpyro.infer.inspect import get_model_relations
@@ -244,7 +243,7 @@ class BayesianModel(nnx.Module):
244
243
  return log_posterior_prob
245
244
 
246
245
  @cached_property
247
- def _parameter_names(self) -> list[str]:
246
+ def parameter_names(self) -> list[str]:
248
247
  """
249
248
  A list of parameter names for the model.
250
249
  """
@@ -269,7 +268,7 @@ class BayesianModel(nnx.Module):
269
268
  """
270
269
  input_params = {}
271
270
 
272
- for index, key in enumerate(self._parameter_names):
271
+ for index, key in enumerate(self.parameter_names):
273
272
  input_params[key] = theta[index]
274
273
 
275
274
  return input_params
@@ -279,9 +278,9 @@ class BayesianModel(nnx.Module):
279
278
  Convert a dictionary of parameters to an array of parameters.
280
279
  """
281
280
 
282
- theta = jnp.zeros(len(self._parameter_names))
281
+ theta = jnp.zeros(len(self.parameter_names))
283
282
 
284
- for index, key in enumerate(self._parameter_names):
283
+ for index, key in enumerate(self.parameter_names):
285
284
  theta = theta.at[index].set(dict_of_params[key])
286
285
 
287
286
  return theta
@@ -298,7 +297,7 @@ class BayesianModel(nnx.Module):
298
297
  @jax.jit
299
298
  def prior_sample(key):
300
299
  return Predictive(
301
- self.numpyro_model, return_sites=self._parameter_names, num_samples=num_samples
300
+ self.numpyro_model, return_sites=self.parameter_names, num_samples=num_samples
302
301
  )(key, observed=False)
303
302
 
304
303
  return prior_sample(key)
@@ -324,7 +323,8 @@ class BayesianModel(nnx.Module):
324
323
  """
325
324
  key_prior, key_posterior = jax.random.split(key, 2)
326
325
  n_devices = len(jax.local_devices())
327
- sharding = PositionalSharding(mesh_utils.create_device_mesh((n_devices,)))
326
+ mesh = jax.make_mesh((n_devices,), ("batch",))
327
+ sharding = NamedSharding(mesh, PartitionSpec("batch"))
328
328
 
329
329
  # Sample from prior and correct if the number of samples is not a multiple of the number of devices
330
330
  if num_samples % n_devices != 0:
@@ -63,10 +63,10 @@ def model():
63
63
 
64
64
  @pytest.fixture
65
65
  def sharded_parameters(unidimensional_parameters):
66
- from jax.experimental import mesh_utils
67
- from jax.sharding import PositionalSharding
66
+ from jax.sharding import NamedSharding, PartitionSpec
68
67
 
69
- sharding = PositionalSharding(mesh_utils.create_device_mesh((4,)))
68
+ mesh = jax.make_mesh((4,), ("batch",))
69
+ sharding = NamedSharding(mesh, PartitionSpec("batch"))
70
70
 
71
71
  return jax.device_put(unidimensional_parameters, sharding)
72
72
 
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes