jaxspec 0.2.2.dev0__tar.gz → 0.3.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (120)
  1. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/.github/workflows/test-and-coverage.yml +1 -1
  2. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/.gitignore +2 -0
  3. jaxspec-0.3.1/.python-version +1 -0
  4. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/.readthedocs.yaml +1 -1
  5. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/PKG-INFO +11 -11
  6. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/README.md +1 -1
  7. jaxspec-0.3.1/codecov.yml +2 -0
  8. jaxspec-0.3.1/docs/css/extra.css +13 -0
  9. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/docs/faq/index.md +5 -0
  10. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/docs/index.md +15 -7
  11. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/docs/references/fitting.md +0 -15
  12. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/mkdocs.yml +2 -1
  13. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/pyproject.toml +11 -11
  14. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/src/jaxspec/analysis/_plot.py +5 -5
  15. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/src/jaxspec/analysis/results.py +41 -26
  16. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/src/jaxspec/data/obsconf.py +9 -3
  17. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/src/jaxspec/data/observation.py +3 -1
  18. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/src/jaxspec/data/ogip.py +9 -2
  19. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/src/jaxspec/data/util.py +17 -11
  20. jaxspec-0.3.1/src/jaxspec/experimental/interpolator.py +74 -0
  21. jaxspec-0.3.1/src/jaxspec/experimental/interpolator_jax.py +79 -0
  22. jaxspec-0.3.1/src/jaxspec/experimental/intrument_models.py +159 -0
  23. jaxspec-0.3.1/src/jaxspec/experimental/nested_sampler.py +78 -0
  24. jaxspec-0.3.1/src/jaxspec/experimental/tabulated.py +264 -0
  25. jaxspec-0.3.1/src/jaxspec/fit/__init__.py +3 -0
  26. jaxspec-0.2.2.dev0/src/jaxspec/fit.py → jaxspec-0.3.1/src/jaxspec/fit/_bayesian_model.py +84 -336
  27. {jaxspec-0.2.2.dev0/src/jaxspec/_fit → jaxspec-0.3.1/src/jaxspec/fit}/_build_model.py +42 -6
  28. jaxspec-0.3.1/src/jaxspec/fit/_fitter.py +255 -0
  29. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/src/jaxspec/model/abc.py +52 -80
  30. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/src/jaxspec/model/additive.py +14 -5
  31. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/src/jaxspec/model/background.py +17 -14
  32. jaxspec-0.3.1/src/jaxspec/model/instrument.py +81 -0
  33. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/src/jaxspec/model/list.py +4 -1
  34. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/src/jaxspec/model/multiplicative.py +32 -12
  35. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/src/jaxspec/util/integrate.py +17 -5
  36. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/tests/test_bayesian_model.py +0 -20
  37. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/tests/test_bayesian_model_building.py +88 -1
  38. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/tests/test_fakeit.py +3 -3
  39. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/tests/test_integrate.py +16 -8
  40. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/tests/test_mcmc.py +5 -1
  41. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/.dockerignore +0 -0
  42. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/.github/dependabot.yml +0 -0
  43. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/.github/workflows/documentation-links.yml +0 -0
  44. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/.github/workflows/publish.yml +0 -0
  45. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/.pre-commit-config.yaml +0 -0
  46. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/CODE_OF_CONDUCT.md +0 -0
  47. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/Dockerfile +0 -0
  48. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/LICENSE.md +0 -0
  49. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/docs/contribute/index.md +0 -0
  50. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/docs/contribute/internal.md +0 -0
  51. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/docs/contribute/xspec.md +0 -0
  52. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/docs/css/material.css +0 -0
  53. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/docs/css/mkdocstrings.css +0 -0
  54. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/docs/css/xarray.css +0 -0
  55. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/docs/dev/index.md +0 -0
  56. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/docs/examples/background.md +0 -0
  57. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/docs/examples/build_model.md +0 -0
  58. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/docs/examples/fakeits.md +0 -0
  59. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/docs/examples/fitting_example.md +0 -0
  60. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/docs/examples/index.md +0 -0
  61. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/docs/examples/statics/background_comparison.png +0 -0
  62. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/docs/examples/statics/background_gp.png +0 -0
  63. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/docs/examples/statics/background_spectral.png +0 -0
  64. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/docs/examples/statics/fakeits.png +0 -0
  65. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/docs/examples/statics/fitting_example_corner.png +0 -0
  66. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/docs/examples/statics/fitting_example_ppc.png +0 -0
  67. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/docs/examples/statics/model.png +0 -0
  68. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/docs/examples/statics/rmf.png +0 -0
  69. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/docs/examples/statics/subtract_background.png +0 -0
  70. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/docs/examples/statics/subtract_background_with_errors.png +0 -0
  71. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/docs/faq/cookbook.md +0 -0
  72. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/docs/faq/statics/cstat_vs_chi2.png +0 -0
  73. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/docs/frontpage/installation.md +0 -0
  74. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/docs/javascripts/mathjax.js +0 -0
  75. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/docs/logo/logo_small.svg +0 -0
  76. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/docs/logo/xifu_mini.svg +0 -0
  77. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/docs/references/abundance.md +0 -0
  78. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/docs/references/additive.md +0 -0
  79. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/docs/references/background.md +0 -0
  80. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/docs/references/data.md +0 -0
  81. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/docs/references/integrate.md +0 -0
  82. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/docs/references/model.md +0 -0
  83. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/docs/references/multiplicative.md +0 -0
  84. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/docs/references/results.md +0 -0
  85. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/docs/runtime/diagram.txt +0 -0
  86. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/docs/runtime/result_table.txt +0 -0
  87. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/docs/runtime/various_model_graphs/diagram.txt +0 -0
  88. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/docs/runtime/various_model_graphs/model_complex.txt +0 -0
  89. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/docs/runtime/various_model_graphs/model_simple.txt +0 -0
  90. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/docs/theory/background.md +0 -0
  91. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/docs/theory/bayesian_inference.md +0 -0
  92. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/docs/theory/index.md +0 -0
  93. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/docs/theory/instrument.md +0 -0
  94. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/src/jaxspec/__init__.py +0 -0
  95. {jaxspec-0.2.2.dev0/src/jaxspec/_fit → jaxspec-0.3.1/src/jaxspec/analysis}/__init__.py +0 -0
  96. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/src/jaxspec/analysis/compare.py +0 -0
  97. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/src/jaxspec/data/__init__.py +0 -0
  98. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/src/jaxspec/data/instrument.py +0 -0
  99. {jaxspec-0.2.2.dev0/src/jaxspec/analysis → jaxspec-0.3.1/src/jaxspec/experimental}/__init__.py +0 -0
  100. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/src/jaxspec/model/__init__.py +0 -0
  101. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/src/jaxspec/model/_graph_util.py +0 -0
  102. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/src/jaxspec/scripts/__init__.py +0 -0
  103. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/src/jaxspec/scripts/debug.py +0 -0
  104. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/src/jaxspec/util/__init__.py +0 -0
  105. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/src/jaxspec/util/abundance.py +0 -0
  106. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/src/jaxspec/util/misc.py +0 -0
  107. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/src/jaxspec/util/online_storage.py +0 -0
  108. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/src/jaxspec/util/typing.py +0 -0
  109. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/tests/conftest.py +0 -0
  110. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/tests/data_files.yml +0 -0
  111. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/tests/data_hash.yml +0 -0
  112. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/tests/test_background.py +0 -0
  113. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/tests/test_instruments.py +0 -0
  114. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/tests/test_misc.py +0 -0
  115. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/tests/test_models.py +0 -0
  116. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/tests/test_observation.py +0 -0
  117. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/tests/test_repr.py +0 -0
  118. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/tests/test_results.py +0 -0
  119. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/tests/test_xspec.py +0 -0
  120. {jaxspec-0.2.2.dev0 → jaxspec-0.3.1}/tests/test_xspec_models.py +0 -0
@@ -42,7 +42,7 @@ jobs:
42
42
 
43
43
  - name: "Upload coverage to Codecov"
44
44
  if: steps.filter.outputs.src == 'true'
45
- uses: codecov/codecov-action@v4
45
+ uses: codecov/codecov-action@v5
46
46
  with:
47
47
  token: ${{ secrets.CODECOV_TOKEN }}
48
48
  fail_ci_if_error: true
@@ -197,4 +197,6 @@ cython_debug/
197
197
  .Trashes
198
198
  ehthumbs.db
199
199
  Thumbs.db
200
+
201
+ _old/
200
202
  # End of https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks
@@ -0,0 +1 @@
1
+ >=3.10, <3.13
@@ -20,7 +20,7 @@ build:
20
20
  - uv pip install -r pyproject.toml
21
21
  # Using insiders versions of mkdocs-material & mkdocstrings
22
22
  - uv pip uninstall mkdocs-material # mkdocstrings mkdocstrings-python
23
- - uv pip install git+https://$GH_TOKEN@github.com/squidfunk/mkdocs-material-insiders.git@9.5.36-insiders-4.53.13
23
+ # - uv pip install git+https://$GH_TOKEN@github.com/squidfunk/mkdocs-material-insiders.git@9.5.36-insiders-4.53.13
24
24
  - uv pip install mkdocstrings mkdocstrings-python
25
25
  - uv pip install mkdocs-autorefs
26
26
  - uv pip install mkdocs-jupyter # This is bugged, I enforced it manually, let's see if it works
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: jaxspec
3
- Version: 0.2.2.dev0
3
+ Version: 0.3.1
4
4
  Summary: jaxspec is a bayesian spectral fitting library for X-ray astronomy.
5
5
  Project-URL: Homepage, https://github.com/renecotyfanboy/jaxspec
6
6
  Project-URL: Documentation, https://jaxspec.readthedocs.io/en/latest/
@@ -8,28 +8,28 @@ Author-email: sdupourque <sdupourque@irap.omp.eu>
8
8
  License-Expression: MIT
9
9
  License-File: LICENSE.md
10
10
  Requires-Python: <3.13,>=3.10
11
- Requires-Dist: arviz<0.21.0,>=0.17.1
12
- Requires-Dist: astropy<7,>=6.0.0
11
+ Requires-Dist: arviz<0.23.0,>=0.17.1
12
+ Requires-Dist: astropy<8,>=6.0.0
13
13
  Requires-Dist: catppuccin<3,>=2.3.4
14
14
  Requires-Dist: chainconsumer<2,>=1.1.2
15
15
  Requires-Dist: cmasher<2,>=1.6.3
16
- Requires-Dist: flax<0.11,>=0.10.3
16
+ Requires-Dist: flax>0.10.5
17
17
  Requires-Dist: interpax<0.4,>=0.3.5
18
- Requires-Dist: jax<0.6,>=0.5.0
18
+ Requires-Dist: jax<0.7,>=0.5.0
19
19
  Requires-Dist: jaxns<3,>=2.6.7
20
20
  Requires-Dist: jaxopt<0.9,>=0.8.3
21
21
  Requires-Dist: matplotlib<4,>=3.8.0
22
- Requires-Dist: mendeleev<0.20,>=0.15
22
+ Requires-Dist: mendeleev<1.2,>=0.15
23
23
  Requires-Dist: networkx~=3.1
24
- Requires-Dist: numpy<2.0.0
25
- Requires-Dist: numpyro<0.18,>=0.17.0
24
+ Requires-Dist: numpy<3.0.0
25
+ Requires-Dist: numpyro<0.20,>=0.17.0
26
26
  Requires-Dist: optimistix<0.0.11,>=0.0.10
27
27
  Requires-Dist: pandas<3,>=2.2.0
28
28
  Requires-Dist: pooch<2,>=1.8.2
29
- Requires-Dist: scipy<1.15
29
+ Requires-Dist: scipy<1.16
30
30
  Requires-Dist: seaborn<0.14,>=0.13.1
31
31
  Requires-Dist: simpleeval<1.1.0,>=0.9.13
32
- Requires-Dist: sparse<0.16,>=0.15.4
32
+ Requires-Dist: sparse>0.15
33
33
  Requires-Dist: tinygp<0.4,>=0.3.0
34
34
  Requires-Dist: watermark<3,>=2.4.3
35
35
  Description-Content-Type: text/markdown
@@ -44,7 +44,7 @@ Description-Content-Type: text/markdown
44
44
 
45
45
 
46
46
  [![PyPI - Version](https://img.shields.io/pypi/v/jaxspec?style=for-the-badge&logo=pypi&color=rgb(37%2C%20150%2C%20190))](https://pypi.org/project/jaxspec/)
47
- [![Python package](https://img.shields.io/pypi/pyversions/jaxspec?style=for-the-badge)](https://pypi.org/project/jaxspec/)
47
+ ![Python Version from PEP 621 TOML](https://img.shields.io/python/required-version-toml?tomlFilePath=https%3A%2F%2Fraw.githubusercontent.com%2Frenecotyfanboy%2Fjaxspec%2Frefs%2Fheads%2Fmain%2Fpyproject.toml&style=for-the-badge)
48
48
  [![Read the Docs](https://img.shields.io/readthedocs/jaxspec?style=for-the-badge)](https://jaxspec.readthedocs.io/en/latest/)
49
49
  [![Codecov](https://img.shields.io/codecov/c/github/renecotyfanboy/jaxspec?style=for-the-badge)](https://app.codecov.io/gh/renecotyfanboy/jaxspec)
50
50
  [![Slack](https://img.shields.io/badge/Slack-4A154B?style=for-the-badge&logo=slack&logoColor=white)](https://join.slack.com/t/jaxspec/shared_invite/zt-2cuxkdl2f-t0EEAKP~HBEHKvIUZJL2sg)
@@ -8,7 +8,7 @@
8
8
 
9
9
 
10
10
  [![PyPI - Version](https://img.shields.io/pypi/v/jaxspec?style=for-the-badge&logo=pypi&color=rgb(37%2C%20150%2C%20190))](https://pypi.org/project/jaxspec/)
11
- [![Python package](https://img.shields.io/pypi/pyversions/jaxspec?style=for-the-badge)](https://pypi.org/project/jaxspec/)
11
+ ![Python Version from PEP 621 TOML](https://img.shields.io/python/required-version-toml?tomlFilePath=https%3A%2F%2Fraw.githubusercontent.com%2Frenecotyfanboy%2Fjaxspec%2Frefs%2Fheads%2Fmain%2Fpyproject.toml&style=for-the-badge)
12
12
  [![Read the Docs](https://img.shields.io/readthedocs/jaxspec?style=for-the-badge)](https://jaxspec.readthedocs.io/en/latest/)
13
13
  [![Codecov](https://img.shields.io/codecov/c/github/renecotyfanboy/jaxspec?style=for-the-badge)](https://app.codecov.io/gh/renecotyfanboy/jaxspec)
14
14
  [![Slack](https://img.shields.io/badge/Slack-4A154B?style=for-the-badge&logo=slack&logoColor=white)](https://join.slack.com/t/jaxspec/shared_invite/zt-2cuxkdl2f-t0EEAKP~HBEHKvIUZJL2sg)
@@ -0,0 +1,2 @@
1
+ ignore:
2
+ - "src/jaxspec/experimental"
@@ -0,0 +1,13 @@
1
+ a.card-link center {
2
+ width: fit-content;
3
+ display: block;
4
+ margin: 0 auto;
5
+ }
6
+
7
+ a.card-link .card-title {
8
+ font-size: 1.5em;
9
+ }
10
+
11
+ a.card-link {
12
+ color: initial; // Make cards have white text until hovered.
13
+ }
@@ -1,5 +1,10 @@
1
1
  # Frequently asked questions
2
2
 
3
+ ## How can I load multiple spectra to fit ?
4
+
5
+ Simply pass a list or dictionnary of [`ObsConfiguration`][jaxspec.data.ObsConfiguration] objects when building your
6
+ fitter object.
7
+
3
8
  ## Why should I use `jaxspec` over `xspec` or associated ?
4
9
 
5
10
  We have taken great care to make `jaxspec` as easy to use as possible. It can be installed with `pip install jaxspec`
@@ -11,13 +11,21 @@ by combining components, and fit it to one or multiple observed spectra using Ba
11
11
 
12
12
  ## Getting started
13
13
 
14
- <div class="grid">
15
-
16
- <a href="frontpage/installation/" class="card" style="font-size: 1.2em;">🛠️ Installation</a>
17
- <a href="examples/fitting_example/" class="card" style="font-size: 1.2em;">🚀 Quickstart</a>
18
- <a href="examples/" class="card" style="font-size: 1.2em;">📚 Examples</a>
19
- <a href="contribute/" class="card" style="font-size: 1.2em;">🤝 Contribute</a>
20
14
 
15
+ <div class="grid cards" markdown>
16
+
17
+ - <a class="card-link" href="frontpage/installation/" target="_blank" rel="noreferrer">
18
+ <span class="card-title center">🛠️ Installation</span>
19
+ </a>
20
+ - <a class="card-link" href="examples/fitting_example/" target="_blank" rel="noreferrer">
21
+ <span class="card-title center">🚀 Quickstart</span>
22
+ </a>
23
+ - <a class="card-link" href="examples/" target="_blank" rel="noreferrer">
24
+ <span class="card-title center">📚 Examples</span>
25
+ </a>
26
+ - <a class="card-link" href="contribute/" target="_blank" rel="noreferrer">
27
+ <span class="card-title center">🤝 Contribute</span>
28
+ </a>
21
29
  </div>
22
30
 
23
31
  ## How does it work?
@@ -33,7 +41,7 @@ by combining components, and fit it to one or multiple observed spectra using Ba
33
41
  Basically, the use of `JAX` as backend allows our models to be differentiable and computable on accelerators, and `numpyro`
34
42
  gives access to appropriate samplers such as the No U-Turn Sampler (NUTS) and Hamiltonian Monte Carlo (HMC).
35
43
 
36
- ## Citation
44
+ ## Citation
37
45
 
38
46
  If you use `jaxspec` in your research, please consider citing the following article
39
47
 
@@ -7,24 +7,9 @@
7
7
 
8
8
  ## Fitter classes
9
9
 
10
- ::: jaxspec.fit.BayesianModelFitter
11
- options:
12
- show_root_heading: true
13
- show_root_full_path: false
14
- show_root_toc_entry: true
15
- heading_level: 3
16
-
17
-
18
10
  ::: jaxspec.fit.MCMCFitter
19
11
  options:
20
12
  show_root_heading: true
21
13
  show_root_full_path: false
22
14
  show_root_toc_entry: true
23
15
  heading_level: 3
24
-
25
- ::: jaxspec.fit.NSFitter
26
- options:
27
- show_root_heading: true
28
- show_root_full_path: false
29
- show_root_toc_entry: true
30
- heading_level: 3
@@ -87,7 +87,7 @@ theme:
87
87
  plugins:
88
88
  - search
89
89
  - autorefs
90
- - typeset
90
+ # - typeset
91
91
  - mkdocs-jupyter:
92
92
  include_source: True
93
93
  ignore_h1_titles: True
@@ -172,6 +172,7 @@ extra_css:
172
172
  - css/material.css
173
173
  - css/mkdocstrings.css
174
174
  - css/xarray.css
175
+ - css/extra.css
175
176
 
176
177
  extra_javascript:
177
178
  - javascripts/mathjax.js
@@ -1,36 +1,36 @@
1
1
  [project]
2
2
  name = "jaxspec"
3
- version = "0.2.2dev"
3
+ version = "0.3.1"
4
4
  description = "jaxspec is a bayesian spectral fitting library for X-ray astronomy."
5
5
  authors = [{ name = "sdupourque", email = "sdupourque@irap.omp.eu" }]
6
6
  requires-python = ">=3.10,<3.13"
7
7
  readme = "README.md"
8
8
  license = "MIT"
9
9
  dependencies = [
10
- "jax>=0.5.0,<0.6",
11
- "numpy<2.0.0",
10
+ "jax>=0.5.0,<0.7",
11
+ "numpy<3.0.0",
12
12
  "pandas>=2.2.0,<3",
13
- "astropy>=6.0.0,<7",
14
- "numpyro>=0.17.0,<0.18",
13
+ "astropy>=6.0.0,<8",
14
+ "numpyro>=0.17.0,<0.20",
15
15
  "networkx~=3.1",
16
16
  "matplotlib>=3.8.0,<4",
17
- "arviz>=0.17.1,<0.21.0",
17
+ "arviz>=0.17.1,<0.23.0",
18
18
  "chainconsumer>=1.1.2,<2",
19
19
  "simpleeval>=0.9.13,<1.1.0",
20
20
  "cmasher>=1.6.3,<2",
21
21
  "jaxopt>=0.8.3,<0.9",
22
22
  "tinygp>=0.3.0,<0.4",
23
23
  "seaborn>=0.13.1,<0.14",
24
- "sparse>=0.15.4,<0.16",
24
+ "sparse>0.15",
25
25
  "optimistix>=0.0.10,<0.0.11",
26
- "scipy<1.15",
27
- "mendeleev>=0.15,<0.20",
26
+ "scipy<1.16",
27
+ "mendeleev>=0.15,<1.2",
28
28
  "jaxns>=2.6.7,<3",
29
29
  "pooch>=1.8.2,<2",
30
30
  "interpax>=0.3.5,<0.4",
31
31
  "watermark>=2.4.3,<3",
32
32
  "catppuccin>=2.3.4,<3",
33
- "flax>=0.10.3,<0.11",
33
+ "flax>0.10.5",
34
34
  ]
35
35
 
36
36
  [project.urls]
@@ -58,7 +58,7 @@ test = [
58
58
  ]
59
59
  dev = [
60
60
  "pre-commit>=3.5,<5.0",
61
- "ruff>=0.2.1,<0.10.0",
61
+ "ruff>=0.2.1,<0.15.0",
62
62
  "jupyterlab>=4.0.7,<5",
63
63
  "notebook>=7.0.6,<8",
64
64
  "ipywidgets>=8.1.1,<9",
@@ -59,8 +59,8 @@ def _plot_poisson_data_with_error(
59
59
  y,
60
60
  xerr=np.abs(x_bins - np.sqrt(x_bins[0] * x_bins[1])),
61
61
  yerr=[
62
- y - y_low,
63
- y_high - y,
62
+ np.maximum(y - y_low, 0),
63
+ np.maximum(y_high - y, 0),
64
64
  ],
65
65
  color=color,
66
66
  linestyle=linestyle,
@@ -149,13 +149,13 @@ def _compute_effective_area(
149
149
  mid_bins_arf = obsconf.in_energies.mean(axis=0) * u.keV
150
150
  mid_bins_arf = mid_bins_arf.to(x_unit, u.spectral())
151
151
  e_grid = np.linspace(*xbins, 10)
152
- interpolated_arf = np.interp(e_grid, mid_bins_arf, obsconf.area)
152
+ interpolated_arf = np.interp(e_grid.value, mid_bins_arf.value, obsconf.area)
153
153
  integrated_arf = (
154
- trapezoid(interpolated_arf, x=e_grid, axis=0)
154
+ trapezoid(interpolated_arf, x=e_grid.value, axis=0)
155
155
  / (
156
156
  np.abs(
157
157
  xbins[1] - xbins[0]
158
- ) # Must fold in abs because some units reverse the ordering of the bins
158
+ ).value # Must fold in abs because some units reverse the ordering of the bins
159
159
  )
160
160
  * u.cm**2
161
161
  )
@@ -42,6 +42,11 @@ V = TypeVar("V")
42
42
  T = TypeVar("T")
43
43
 
44
44
 
45
+ def auto_in_axes(pytree, axis=0):
46
+ """Return a pytree of 0/None depending on whether the leaf is batched."""
47
+ return jax.tree.map(lambda x: axis if (hasattr(x, "ndim") and x.ndim > 0) else None, pytree)
48
+
49
+
45
50
  class FitResult:
46
51
  """
47
52
  Container for the result of a fit using any ModelFitter class.
@@ -54,17 +59,17 @@ class FitResult:
54
59
  inference_data: az.InferenceData,
55
60
  background_model: BackgroundModel = None,
56
61
  ):
57
- self.model = bayesian_fitter.model
62
+ self.model = bayesian_fitter.spectral_model
58
63
  self.bayesian_fitter = bayesian_fitter
59
64
  self.inference_data = inference_data
60
- self.obsconfs = bayesian_fitter.observation_container
65
+ self.obsconfs = bayesian_fitter._observation_container
61
66
  self.background_model = background_model
62
67
 
63
68
  # Add the model used in fit to the metadata
64
69
  for group in self.inference_data.groups():
65
70
  group_name = group.split("/")[-1]
66
71
  metadata = getattr(self.inference_data, group_name).attrs
67
- metadata["model"] = str(self.model)
72
+ # metadata["model"] = str(self.model)
68
73
  # TODO : Store metadata about observations used in the fitting process
69
74
 
70
75
  @property
@@ -78,6 +83,7 @@ class FitResult:
78
83
  def _ppc_folded_branches(self, obs_id):
79
84
  obs = self.obsconfs[obs_id]
80
85
 
86
+ # Slice the parameters corresponding to the current ObsID
81
87
  if len(next(iter(self.input_parameters.values())).shape) > 2:
82
88
  idx = list(self.obsconfs.keys()).index(obs_id)
83
89
  obs_parameters = jax.tree.map(lambda x: x[..., idx], self.input_parameters)
@@ -85,7 +91,7 @@ class FitResult:
85
91
  else:
86
92
  obs_parameters = self.input_parameters
87
93
 
88
- if self.bayesian_fitter.sparse:
94
+ if self.bayesian_fitter.settings.get("sparse", False):
89
95
  transfer_matrix = BCOO.from_scipy_sparse(
90
96
  obs.transfer_matrix.data.to_scipy_sparse().tocsr()
91
97
  )
@@ -98,6 +104,7 @@ class FitResult:
98
104
  flux_func = jax.jit(
99
105
  jax.vmap(jax.vmap(lambda p: self.model.photon_flux(p, *energies, split_branches=True)))
100
106
  )
107
+
101
108
  convolve_func = jax.jit(
102
109
  jax.vmap(jax.vmap(lambda flux: jnp.clip(transfer_matrix @ flux, a_min=1e-6)))
103
110
  )
@@ -115,7 +122,9 @@ class FitResult:
115
122
 
116
123
  samples_shape = (len(posterior.coords["chain"]), len(posterior.coords["draw"]))
117
124
 
118
- total_shape = tuple(posterior.sizes[d] for d in posterior.coords)
125
+ total_shape = tuple(
126
+ posterior.sizes[d] for d in posterior.coords if not (("obs" in d) or ("bkg" in d))
127
+ )
119
128
 
120
129
  posterior = {key: posterior[key].data for key in posterior.data_vars}
121
130
 
@@ -124,13 +133,14 @@ class FitResult:
124
133
 
125
134
  for key, value in input_parameters.items():
126
135
  module, parameter = key.rsplit("_", 1)
136
+ key_to_search = f"mod/~/{module}_{parameter}"
127
137
 
128
- if f"{module}_{parameter}" in posterior.keys():
138
+ if key_to_search in posterior.keys():
129
139
  # We add as extra dimension as there might be different values per observation
130
- if posterior[f"{module}_{parameter}"].shape == samples_shape:
131
- to_set = posterior[f"{module}_{parameter}"][..., None]
140
+ if posterior[key_to_search].shape == samples_shape:
141
+ to_set = posterior[key_to_search][..., None]
132
142
  else:
133
- to_set = posterior[f"{module}_{parameter}"]
143
+ to_set = posterior[key_to_search]
134
144
 
135
145
  input_parameters[f"{module}_{parameter}"] = to_set
136
146
 
@@ -299,7 +309,7 @@ class FitResult:
299
309
 
300
310
  return value
301
311
 
302
- def to_chain(self, name: str) -> Chain:
312
+ def to_chain(self, name: str, parameter_kind="mod") -> Chain:
303
313
  """
304
314
  Return a ChainConsumer Chain object from the posterior distribution of the parameters_type.
305
315
 
@@ -308,9 +318,7 @@ class FitResult:
308
318
  """
309
319
 
310
320
  keys_to_drop = [
311
- key
312
- for key in self.inference_data.posterior.keys()
313
- if (key.startswith("_") or key.startswith("bkg"))
321
+ key for key in self.inference_data.posterior.keys() if not key.startswith("mod")
314
322
  ]
315
323
 
316
324
  reduced_id = az.extract(
@@ -338,6 +346,8 @@ class FitResult:
338
346
 
339
347
  df = pd.concat(df_list, axis=1)
340
348
 
349
+ df = df.rename(columns=lambda x: x.split("/~/")[-1])
350
+
341
351
  return Chain(samples=df, name=name)
342
352
 
343
353
  @property
@@ -450,7 +460,7 @@ class FitResult:
450
460
  legend_labels = []
451
461
 
452
462
  count = az.extract(
453
- self.inference_data, var_names=f"obs_{obs_id}", group="posterior_predictive"
463
+ self.inference_data, var_names=f"obs/~/{obs_id}", group="posterior_predictive"
454
464
  ).values.T
455
465
 
456
466
  xbins, exposure, integrated_arf = _compute_effective_area(obsconf, x_unit)
@@ -465,7 +475,9 @@ class FitResult:
465
475
  case "photon_flux_density":
466
476
  denominator = (xbins[1] - xbins[0]) * integrated_arf * exposure
467
477
 
468
- y_samples = (count * u.ct / denominator).to(y_units)
478
+ y_samples = count * u.ct / denominator
479
+
480
+ y_samples = y_samples.to(y_units)
469
481
 
470
482
  y_observed, y_observed_low, y_observed_high = _error_bars_for_observed_data(
471
483
  obsconf.folded_counts.data, denominator, y_units
@@ -491,8 +503,8 @@ class FitResult:
491
503
  alpha=0.7,
492
504
  )
493
505
 
494
- lowest_y = y_observed.min()
495
- highest_y = y_observed.max()
506
+ lowest_y = np.nanmin(y_observed)
507
+ highest_y = np.nanmax(y_observed)
496
508
 
497
509
  legend_plots.append((true_data_plot,))
498
510
  legend_labels.append("Observed")
@@ -522,7 +534,10 @@ class FitResult:
522
534
  count.reshape((count.shape[0] * count.shape[1], -1))
523
535
  * u.ct
524
536
  / denominator
525
- ).to(y_units)
537
+ )
538
+
539
+ y_samples = y_samples.to(y_units)
540
+
526
541
  component_plot = _plot_binned_samples_with_error(
527
542
  ax[0],
528
543
  xbins.value,
@@ -545,7 +560,7 @@ class FitResult:
545
560
  if self.background_model is None
546
561
  else az.extract(
547
562
  self.inference_data,
548
- var_names=f"bkg_{obs_id}",
563
+ var_names=f"bkg/~/{obs_id}",
549
564
  group="posterior_predictive",
550
565
  ).values.T
551
566
  )
@@ -577,18 +592,18 @@ class FitResult:
577
592
  alpha=0.7,
578
593
  )
579
594
 
580
- lowest_y = min(lowest_y, y_observed_bkg.min())
581
- highest_y = max(highest_y, y_observed_bkg.max())
595
+ # lowest_y = np.nanmin(lowest_y.min, np.nanmin(y_observed_bkg.value).astype(float))
596
+ # highest_y = np.nanmax(highest_y.value.astype(float), np.nanmax(y_observed_bkg.value).astype(float))
582
597
 
583
598
  legend_plots.append((true_bkg_plot,))
584
599
  legend_labels.append("Observed (bkg)")
585
600
  legend_plots += model_bkg_plot
586
601
  legend_labels.append("Model (bkg)")
587
602
 
588
- max_residuals = np.max(np.abs(residual_samples))
603
+ max_residuals = min(3.5, np.nanmax(np.abs(residual_samples)))
589
604
 
590
605
  ax[0].loglog()
591
- ax[1].set_ylim(-max(3.5, max_residuals), +max(3.5, max_residuals))
606
+ ax[1].set_ylim(-np.nanmax([3.5, max_residuals]), +np.nanmax([3.5, max_residuals]))
592
607
  ax[0].set_ylabel(f"Folded spectrum\n [{y_units:latex_inline}]")
593
608
  ax[1].set_ylabel("Residuals \n" + r"[$\sigma$]")
594
609
 
@@ -635,9 +650,9 @@ class FitResult:
635
650
 
636
651
  fig.align_ylabels()
637
652
  plt.subplots_adjust(hspace=0.0)
653
+ fig.suptitle(f"Posterior predictive - {obs_id}" if title is None else title)
638
654
  fig.tight_layout()
639
655
  figure_list.append(fig)
640
- fig.suptitle(f"Posterior predictive - {obs_id}" if title is None else title)
641
656
  # fig.show()
642
657
 
643
658
  plt.tight_layout()
@@ -651,9 +666,9 @@ class FitResult:
651
666
  """
652
667
 
653
668
  consumer = ChainConsumer()
654
- consumer.add_chain(self.to_chain(self.model.to_string()))
669
+ consumer.add_chain(self.to_chain("Model"))
655
670
 
656
- return consumer.analysis.get_latex_table(caption="Results of the fit", label="tab:results")
671
+ return consumer.analysis.get_latex_table(caption="Fit result", label="tab:results")
657
672
 
658
673
  def plot_corner(
659
674
  self,
@@ -85,13 +85,20 @@ class ObsConfiguration(xr.Dataset):
85
85
 
86
86
  from .util import data_path_finder
87
87
 
88
- arf_path_default, rmf_path_default, bkg_path_default = data_path_finder(pha_path)
88
+ arf_path_default, rmf_path_default, bkg_path_default = data_path_finder(
89
+ pha_path,
90
+ require_arf=(arf_path is None) and (arf_path != ""),
91
+ require_rmf=rmf_path is None,
92
+ require_bkg=bkg_path is None,
93
+ )
89
94
 
90
95
  arf_path = arf_path_default if arf_path is None else arf_path
91
96
  rmf_path = rmf_path_default if rmf_path is None else rmf_path
92
97
  bkg_path = bkg_path_default if bkg_path is None else bkg_path
93
98
 
94
- instrument = Instrument.from_ogip_file(rmf_path, arf_path=arf_path)
99
+ instrument = Instrument.from_ogip_file(
100
+ rmf_path, arf_path=arf_path if arf_path != "" else None
101
+ )
95
102
  observation = Observation.from_pha_file(pha_path, bkg_path=bkg_path)
96
103
 
97
104
  return cls.from_instrument(
@@ -141,7 +148,6 @@ class ObsConfiguration(xr.Dataset):
141
148
  transfer_matrix = grouping @ (redistribution * area * exposure)
142
149
 
143
150
  # Exclude bins out of the considered energy range, and bins without contribution from the RMF
144
-
145
151
  row_idx = (e_min > low_energy) & (e_max < high_energy) & (grouping.sum(axis=1) > 0)
146
152
  col_idx = (e_min_unfolded > 0) & (redistribution.sum(axis=0) > 0)
147
153
 
@@ -164,7 +164,9 @@ class Observation(xr.Dataset):
164
164
  """
165
165
  from .util import data_path_finder
166
166
 
167
- arf_path, rmf_path, bkg_path_default = data_path_finder(pha_path)
167
+ arf_path, rmf_path, bkg_path_default = data_path_finder(
168
+ pha_path, require_arf=False, require_rmf=False, require_bkg=False
169
+ )
168
170
  bkg_path = bkg_path_default if bkg_path is None else bkg_path
169
171
 
170
172
  pha = DataPHA.from_file(pha_path)
@@ -109,7 +109,7 @@ class DataPHA:
109
109
  raise ValueError("No QUALITY column found in the PHA file.")
110
110
 
111
111
  if "BACKSCAL" in header:
112
- backscal = header["BACKSCAL"] * np.ones_like(data["CHANNEL"])
112
+ backscal = header["BACKSCAL"] * np.ones_like(data["CHANNEL"], dtype=float)
113
113
  elif "BACKSCAL" in data.colnames:
114
114
  backscal = data["BACKSCAL"]
115
115
  else:
@@ -138,7 +138,14 @@ class DataPHA:
138
138
  "flags": flags,
139
139
  }
140
140
 
141
- return cls(data["CHANNEL"], data["COUNTS"], header["EXPOSURE"], **kwargs)
141
+ if "COUNTS" in data.colnames:
142
+ counts = data["COUNTS"]
143
+ elif "RATE" in data.colnames:
144
+ counts = data["RATE"] * header["EXPOSURE"]
145
+ else:
146
+ raise ValueError("No COUNTS or RATE column found in the PHA file.")
147
+
148
+ return cls(data["CHANNEL"], counts, header["EXPOSURE"], **kwargs)
142
149
 
143
150
 
144
151
  class DataARF:
@@ -228,12 +228,17 @@ def fakeit_for_multiple_parameters(
228
228
  return fakeits[0] if len(fakeits) == 1 else fakeits
229
229
 
230
230
 
231
- def data_path_finder(pha_path: str) -> tuple[str | None, str | None, str | None]:
231
+ def data_path_finder(
232
+ pha_path: str, require_arf: bool = True, require_rmf: bool = True, require_bkg: bool = False
233
+ ) -> tuple[str | None, str | None, str | None]:
232
234
  """
233
235
  Function which tries its best to find the ARF, RMF and BKG files associated with a given PHA file.
234
236
 
235
237
  Parameters:
236
238
  pha_path: The PHA file path.
239
+ require_arf: Whether to raise an error if the ARF file is not found.
240
+ require_rmf: Whether to raise an error if the RMF file is not found.
241
+ require_bkg: Whether to raise an error if the BKG file is not found.
237
242
 
238
243
  Returns:
239
244
  arf_path: The ARF file path.
@@ -241,23 +246,24 @@ def data_path_finder(pha_path: str) -> tuple[str | None, str | None, str | None]
241
246
  bkg_path: The BKG file path.
242
247
  """
243
248
 
244
- def find_path(file_name: str, directory: str) -> str | None:
245
- if file_name.lower() != "none" and file_name != "":
246
- return find_file_or_compressed_in_dir(file_name, directory)
247
- else:
248
- return None
249
+ def find_path(file_name: str, directory: str, raise_err: bool = True) -> str | None:
250
+ if raise_err:
251
+ if file_name.lower() != "none" and file_name != "":
252
+ return find_file_or_compressed_in_dir(file_name, directory, raise_err)
253
+
254
+ return None
249
255
 
250
256
  header = fits.getheader(pha_path, "SPECTRUM")
251
257
  directory = str(Path(pha_path).parent)
252
258
 
253
- arf_path = find_path(header.get("ANCRFILE", "none"), directory)
254
- rmf_path = find_path(header.get("RESPFILE", "none"), directory)
255
- bkg_path = find_path(header.get("BACKFILE", "none"), directory)
259
+ arf_path = find_path(header.get("ANCRFILE", "none"), directory, require_arf)
260
+ rmf_path = find_path(header.get("RESPFILE", "none"), directory, require_rmf)
261
+ bkg_path = find_path(header.get("BACKFILE", "none"), directory, require_bkg)
256
262
 
257
263
  return arf_path, rmf_path, bkg_path
258
264
 
259
265
 
260
- def find_file_or_compressed_in_dir(path: str | Path, directory: str | Path) -> str:
266
+ def find_file_or_compressed_in_dir(path: str | Path, directory: str | Path, raise_err: bool) -> str:
261
267
  """
262
268
  Try to find a file or its .gz compressed version in a given directory and return
263
269
  the full path of the file.
@@ -275,5 +281,5 @@ def find_file_or_compressed_in_dir(path: str | Path, directory: str | Path) -> s
275
281
  if file.suffix == ".gz":
276
282
  return str(file)
277
283
 
278
- else:
284
+ elif raise_err:
279
285
  raise FileNotFoundError(f"Can't find {path}(.gz) in {directory}.")