jaxspec 0.3.2__tar.gz → 0.3.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {jaxspec-0.3.2 → jaxspec-0.3.3}/PKG-INFO +2 -6
- {jaxspec-0.3.2 → jaxspec-0.3.3}/mkdocs.yml +1 -1
- {jaxspec-0.3.2 → jaxspec-0.3.3}/pyproject.toml +4 -8
- {jaxspec-0.3.2 → jaxspec-0.3.3}/src/jaxspec/analysis/results.py +61 -23
- {jaxspec-0.3.2 → jaxspec-0.3.3}/src/jaxspec/data/instrument.py +2 -1
- {jaxspec-0.3.2 → jaxspec-0.3.3}/src/jaxspec/data/util.py +2 -2
- {jaxspec-0.3.2 → jaxspec-0.3.3}/src/jaxspec/fit/_fitter.py +13 -5
- {jaxspec-0.3.2 → jaxspec-0.3.3}/src/jaxspec/model/abc.py +4 -3
- {jaxspec-0.3.2 → jaxspec-0.3.3}/src/jaxspec/util/online_storage.py +1 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/tests/conftest.py +10 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/tests/test_fakeit.py +13 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/tests/test_mcmc.py +4 -1
- {jaxspec-0.3.2 → jaxspec-0.3.3}/tests/test_results.py +9 -3
- {jaxspec-0.3.2 → jaxspec-0.3.3}/tests/test_xspec.py +39 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/.dockerignore +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/.github/dependabot.yml +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/.github/workflows/documentation-links.yml +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/.github/workflows/publish.yml +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/.github/workflows/test-and-coverage.yml +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/.gitignore +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/.pre-commit-config.yaml +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/.python-version +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/.readthedocs.yaml +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/CODE_OF_CONDUCT.md +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/Dockerfile +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/LICENSE.md +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/README.md +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/codecov.yml +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/docs/contribute/index.md +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/docs/contribute/internal.md +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/docs/contribute/xspec.md +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/docs/css/extra.css +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/docs/css/material.css +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/docs/css/mkdocstrings.css +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/docs/css/xarray.css +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/docs/dev/index.md +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/docs/examples/background.md +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/docs/examples/build_model.md +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/docs/examples/fakeits.md +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/docs/examples/fitting_example.md +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/docs/examples/index.md +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/docs/examples/statics/background_comparison.png +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/docs/examples/statics/background_gp.png +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/docs/examples/statics/background_spectral.png +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/docs/examples/statics/fakeits.png +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/docs/examples/statics/fitting_example_corner.png +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/docs/examples/statics/fitting_example_ppc.png +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/docs/examples/statics/model.png +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/docs/examples/statics/rmf.png +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/docs/examples/statics/subtract_background.png +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/docs/examples/statics/subtract_background_with_errors.png +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/docs/faq/cookbook.md +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/docs/faq/index.md +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/docs/faq/statics/cstat_vs_chi2.png +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/docs/frontpage/installation.md +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/docs/index.md +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/docs/javascripts/mathjax.js +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/docs/logo/logo_small.svg +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/docs/logo/xifu_mini.svg +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/docs/references/abundance.md +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/docs/references/additive.md +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/docs/references/background.md +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/docs/references/data.md +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/docs/references/fitting.md +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/docs/references/instrument.md +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/docs/references/integrate.md +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/docs/references/model.md +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/docs/references/multiplicative.md +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/docs/references/results.md +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/docs/runtime/diagram.txt +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/docs/runtime/result_table.txt +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/docs/runtime/various_model_graphs/diagram.txt +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/docs/runtime/various_model_graphs/model_complex.txt +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/docs/runtime/various_model_graphs/model_simple.txt +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/docs/theory/background.md +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/docs/theory/bayesian_inference.md +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/docs/theory/index.md +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/docs/theory/instrument.md +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/src/jaxspec/__init__.py +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/src/jaxspec/analysis/__init__.py +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/src/jaxspec/analysis/_plot.py +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/src/jaxspec/analysis/compare.py +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/src/jaxspec/data/__init__.py +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/src/jaxspec/data/obsconf.py +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/src/jaxspec/data/observation.py +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/src/jaxspec/data/ogip.py +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/src/jaxspec/experimental/__init__.py +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/src/jaxspec/experimental/interpolator.py +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/src/jaxspec/experimental/interpolator_jax.py +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/src/jaxspec/experimental/intrument_models.py +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/src/jaxspec/experimental/nested_sampler.py +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/src/jaxspec/experimental/tabulated.py +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/src/jaxspec/fit/__init__.py +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/src/jaxspec/fit/_bayesian_model.py +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/src/jaxspec/fit/_build_model.py +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/src/jaxspec/model/__init__.py +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/src/jaxspec/model/_graph_util.py +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/src/jaxspec/model/additive.py +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/src/jaxspec/model/background.py +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/src/jaxspec/model/instrument.py +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/src/jaxspec/model/list.py +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/src/jaxspec/model/multiplicative.py +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/src/jaxspec/scripts/__init__.py +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/src/jaxspec/scripts/debug.py +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/src/jaxspec/util/__init__.py +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/src/jaxspec/util/abundance.py +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/src/jaxspec/util/integrate.py +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/src/jaxspec/util/misc.py +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/src/jaxspec/util/typing.py +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/tests/data_files.yml +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/tests/data_hash.yml +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/tests/test_background.py +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/tests/test_bayesian_model.py +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/tests/test_bayesian_model_building.py +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/tests/test_instruments.py +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/tests/test_integrate.py +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/tests/test_misc.py +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/tests/test_models.py +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/tests/test_observation.py +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/tests/test_repr.py +0 -0
- {jaxspec-0.3.2 → jaxspec-0.3.3}/tests/test_xspec_models.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: jaxspec
|
|
3
|
-
Version: 0.3.2
|
|
3
|
+
Version: 0.3.3
|
|
4
4
|
Summary: jaxspec is a bayesian spectral fitting library for X-ray astronomy.
|
|
5
5
|
Project-URL: Homepage, https://github.com/renecotyfanboy/jaxspec
|
|
6
6
|
Project-URL: Documentation, https://jaxspec.readthedocs.io/en/latest/
|
|
@@ -16,19 +16,15 @@ Requires-Dist: cmasher<2,>=1.6.3
|
|
|
16
16
|
Requires-Dist: flax>0.10.5
|
|
17
17
|
Requires-Dist: interpax<0.4,>=0.3.5
|
|
18
18
|
Requires-Dist: jax<0.7,>=0.5.0
|
|
19
|
-
Requires-Dist: jaxns<3,>=2.6.7
|
|
20
|
-
Requires-Dist: jaxopt<0.9,>=0.8.3
|
|
21
19
|
Requires-Dist: matplotlib<4,>=3.8.0
|
|
22
20
|
Requires-Dist: mendeleev<1.2,>=0.15
|
|
23
21
|
Requires-Dist: networkx~=3.1
|
|
24
22
|
Requires-Dist: numpy<3.0.0
|
|
25
23
|
Requires-Dist: numpyro<0.20,>=0.17.0
|
|
26
|
-
Requires-Dist: optimistix<0.0.12,>=0.0.10
|
|
27
24
|
Requires-Dist: pandas<3,>=2.2.0
|
|
28
25
|
Requires-Dist: pooch<2,>=1.8.2
|
|
29
26
|
Requires-Dist: scipy<1.16
|
|
30
|
-
Requires-Dist: seaborn
|
|
31
|
-
Requires-Dist: simpleeval<1.1.0,>=0.9.13
|
|
27
|
+
Requires-Dist: seaborn>=0.13.2
|
|
32
28
|
Requires-Dist: sparse>0.15
|
|
33
29
|
Requires-Dist: tinygp<0.4,>=0.3.0
|
|
34
30
|
Requires-Dist: watermark<3,>=2.4.3
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "jaxspec"
|
|
3
|
-
version = "0.3.2"
|
|
3
|
+
version = "0.3.3"
|
|
4
4
|
description = "jaxspec is a bayesian spectral fitting library for X-ray astronomy."
|
|
5
5
|
authors = [{ name = "sdupourque", email = "sdupourque@irap.omp.eu" }]
|
|
6
6
|
requires-python = ">=3.10,<3.13"
|
|
@@ -16,21 +16,17 @@ dependencies = [
|
|
|
16
16
|
"matplotlib>=3.8.0,<4",
|
|
17
17
|
"arviz>=0.17.1,<0.23.0",
|
|
18
18
|
"chainconsumer>=1.1.2,<2",
|
|
19
|
-
"simpleeval>=0.9.13,<1.1.0",
|
|
20
19
|
"cmasher>=1.6.3,<2",
|
|
21
|
-
"jaxopt>=0.8.3,<0.9",
|
|
22
20
|
"tinygp>=0.3.0,<0.4",
|
|
23
|
-
"seaborn>=0.13.1,<0.14",
|
|
24
21
|
"sparse>0.15",
|
|
25
|
-
"optimistix>=0.0.10,<0.0.12",
|
|
26
22
|
"scipy<1.16",
|
|
27
23
|
"mendeleev>=0.15,<1.2",
|
|
28
|
-
"jaxns>=2.6.7,<3",
|
|
29
24
|
"pooch>=1.8.2,<2",
|
|
30
25
|
"interpax>=0.3.5,<0.4",
|
|
31
26
|
"watermark>=2.4.3,<3",
|
|
32
27
|
"catppuccin>=2.3.4,<3",
|
|
33
28
|
"flax>0.10.5",
|
|
29
|
+
"seaborn>=0.13.2",
|
|
34
30
|
]
|
|
35
31
|
|
|
36
32
|
[project.urls]
|
|
@@ -44,14 +40,14 @@ jaxspec-debug-info = "jaxspec.scripts.debug:debug_info"
|
|
|
44
40
|
docs = [
|
|
45
41
|
"mkdocs>=1.6.1,<2",
|
|
46
42
|
"mkdocs-material>=9.4.6,<10",
|
|
47
|
-
"mkdocstrings[python]>=0.24,<
|
|
43
|
+
"mkdocstrings[python]>=0.24,<1.1",
|
|
48
44
|
"mkdocs-jupyter>=0.25.0,<0.26",
|
|
49
45
|
]
|
|
50
46
|
test = [
|
|
51
47
|
"chex>=0.1.83,<0.2",
|
|
52
48
|
"mktestdocs>=0.2.1,<0.3",
|
|
53
49
|
"coverage>=7.3.2,<8",
|
|
54
|
-
"pytest-cov>=4.1,<
|
|
50
|
+
"pytest-cov>=4.1,<8.0",
|
|
55
51
|
"flake8>=7.0.0,<8",
|
|
56
52
|
"pytest>=8.0.0,<9",
|
|
57
53
|
"testbook>=0.4.2,<0.5",
|
|
@@ -77,8 +77,9 @@ class FitResult:
|
|
|
77
77
|
r"""
|
|
78
78
|
Convergence of the chain as computed by the $\hat{R}$ statistic.
|
|
79
79
|
"""
|
|
80
|
+
rhat = az.rhat(self.inference_data)
|
|
80
81
|
|
|
81
|
-
return
|
|
82
|
+
return bool((rhat.to_array() < 1.01).all())
|
|
82
83
|
|
|
83
84
|
def _ppc_folded_branches(self, obs_id):
|
|
84
85
|
obs = self.obsconfs[obs_id]
|
|
@@ -167,6 +168,7 @@ class FitResult:
|
|
|
167
168
|
e_max: float,
|
|
168
169
|
unit: Unit = u.photon / u.cm**2 / u.s,
|
|
169
170
|
register: bool = False,
|
|
171
|
+
n_points: int = 100,
|
|
170
172
|
) -> ArrayLike:
|
|
171
173
|
"""
|
|
172
174
|
Compute the unfolded photon flux in a given energy band. The flux is then added to
|
|
@@ -177,6 +179,7 @@ class FitResult:
|
|
|
177
179
|
e_max: The upper bound of the energy band in observer frame.
|
|
178
180
|
unit: The unit of the photon flux.
|
|
179
181
|
register: Whether to register the flux with the other posterior parameters.
|
|
182
|
+
n_points: The number of points to use for computing the unfolded spectrum.
|
|
180
183
|
|
|
181
184
|
!!! warning
|
|
182
185
|
Computation of the folded flux is not implemented yet. Feel free to open an
|
|
@@ -188,18 +191,25 @@ class FitResult:
|
|
|
188
191
|
def vectorized_flux(*pars):
|
|
189
192
|
parameters_pytree = jax.tree.unflatten(pytree_def, pars)
|
|
190
193
|
return self.model.photon_flux(
|
|
191
|
-
parameters_pytree, jnp.asarray([e_min]), jnp.asarray([e_max]), n_points=
|
|
194
|
+
parameters_pytree, jnp.asarray([e_min]), jnp.asarray([e_max]), n_points=n_points
|
|
192
195
|
)[0]
|
|
193
196
|
|
|
194
197
|
flat_tree, pytree_def = jax.tree.flatten(self.input_parameters)
|
|
195
198
|
flux = vectorized_flux(*flat_tree)
|
|
196
|
-
conversion_factor = (u.photon / u.cm**2 / u.s).to(unit)
|
|
197
|
-
value = flux * conversion_factor
|
|
199
|
+
conversion_factor = float((u.photon / u.cm**2 / u.s).to(unit))
|
|
200
|
+
value = np.asarray(flux * conversion_factor)
|
|
198
201
|
|
|
199
202
|
if register:
|
|
200
|
-
self.inference_data.posterior[f"photon_flux_{e_min:.1f}_{e_max:.1f}"] = (
|
|
201
|
-
|
|
202
|
-
|
|
203
|
+
self.inference_data.posterior[f"mod/~/photon_flux_{e_min:.1f}_{e_max:.1f}"] = (
|
|
204
|
+
xr.DataArray(
|
|
205
|
+
value,
|
|
206
|
+
dims=self.inference_data.posterior.dims,
|
|
207
|
+
coords={
|
|
208
|
+
"chain": self.inference_data.posterior.chain,
|
|
209
|
+
"draw": self.inference_data.posterior.draw,
|
|
210
|
+
},
|
|
211
|
+
name=f"mod/~/photon_flux_{e_min:.1f}_{e_max:.1f}",
|
|
212
|
+
)
|
|
203
213
|
)
|
|
204
214
|
|
|
205
215
|
return value
|
|
@@ -210,6 +220,7 @@ class FitResult:
|
|
|
210
220
|
e_max: float,
|
|
211
221
|
unit: Unit = u.erg / u.cm**2 / u.s,
|
|
212
222
|
register: bool = False,
|
|
223
|
+
n_points: int = 100,
|
|
213
224
|
) -> ArrayLike:
|
|
214
225
|
"""
|
|
215
226
|
Compute the unfolded energy flux in a given energy band. The flux is then added to
|
|
@@ -220,6 +231,7 @@ class FitResult:
|
|
|
220
231
|
e_max: The upper bound of the energy band in observer frame.
|
|
221
232
|
unit: The unit of the energy flux.
|
|
222
233
|
register: Whether to register the flux with the other posterior parameters.
|
|
234
|
+
n_points: The number of points to use for computing the unfolded spectrum.
|
|
223
235
|
|
|
224
236
|
!!! warning
|
|
225
237
|
Computation of the folded flux is not implemented yet. Feel free to open an
|
|
@@ -231,18 +243,25 @@ class FitResult:
|
|
|
231
243
|
def vectorized_flux(*pars):
|
|
232
244
|
parameters_pytree = jax.tree.unflatten(pytree_def, pars)
|
|
233
245
|
return self.model.energy_flux(
|
|
234
|
-
parameters_pytree, jnp.asarray([e_min]), jnp.asarray([e_max]), n_points=
|
|
246
|
+
parameters_pytree, jnp.asarray([e_min]), jnp.asarray([e_max]), n_points=n_points
|
|
235
247
|
)[0]
|
|
236
248
|
|
|
237
249
|
flat_tree, pytree_def = jax.tree.flatten(self.input_parameters)
|
|
238
250
|
flux = vectorized_flux(*flat_tree)
|
|
239
|
-
conversion_factor = (u.keV / u.cm**2 / u.s).to(unit)
|
|
240
|
-
value = flux * conversion_factor
|
|
251
|
+
conversion_factor = float((u.keV / u.cm**2 / u.s).to(unit))
|
|
252
|
+
value = np.asarray(flux * conversion_factor)
|
|
241
253
|
|
|
242
254
|
if register:
|
|
243
|
-
self.inference_data.posterior[f"energy_flux_{e_min:.1f}_{e_max:.1f}"] = (
|
|
244
|
-
|
|
245
|
-
|
|
255
|
+
self.inference_data.posterior[f"mod/~/energy_flux_{e_min:.1f}_{e_max:.1f}"] = (
|
|
256
|
+
xr.DataArray(
|
|
257
|
+
value,
|
|
258
|
+
dims=self.inference_data.posterior.dims,
|
|
259
|
+
coords={
|
|
260
|
+
"chain": self.inference_data.posterior.chain,
|
|
261
|
+
"draw": self.inference_data.posterior.draw,
|
|
262
|
+
},
|
|
263
|
+
name=f"mod/~/energy_flux_{e_min:.1f}_{e_max:.1f}",
|
|
264
|
+
)
|
|
246
265
|
)
|
|
247
266
|
|
|
248
267
|
return value
|
|
@@ -257,6 +276,7 @@ class FitResult:
|
|
|
257
276
|
cosmology: Cosmology = Planck18,
|
|
258
277
|
unit: Unit = u.erg / u.s,
|
|
259
278
|
register: bool = False,
|
|
279
|
+
n_points=100,
|
|
260
280
|
) -> ArrayLike:
|
|
261
281
|
"""
|
|
262
282
|
Compute the luminosity of the source specifying its redshift. The luminosity is then added to
|
|
@@ -294,17 +314,26 @@ class FitResult:
|
|
|
294
314
|
parameters_pytree,
|
|
295
315
|
jnp.asarray([e_min]) * (1 + redshift),
|
|
296
316
|
jnp.asarray([e_max]) * (1 + redshift),
|
|
297
|
-
n_points=
|
|
317
|
+
n_points=n_points,
|
|
298
318
|
)[0]
|
|
299
319
|
|
|
300
320
|
flat_tree, pytree_def = jax.tree.flatten(self.input_parameters)
|
|
301
321
|
flux = vectorized_flux(*flat_tree) * (u.keV / u.cm**2 / u.s)
|
|
302
|
-
value =
|
|
322
|
+
value = np.asarray(
|
|
323
|
+
(flux * (4 * np.pi * cosmology.luminosity_distance(redshift) ** 2)).to(unit)
|
|
324
|
+
)
|
|
303
325
|
|
|
304
326
|
if register:
|
|
305
|
-
self.inference_data.posterior[f"luminosity_{e_min:.1f}_{e_max:.1f}"] = (
|
|
306
|
-
|
|
307
|
-
|
|
327
|
+
self.inference_data.posterior[f"mod/~/luminosity_{e_min:.1f}_{e_max:.1f}"] = (
|
|
328
|
+
xr.DataArray(
|
|
329
|
+
value,
|
|
330
|
+
dims=self.inference_data.posterior.dims,
|
|
331
|
+
coords={
|
|
332
|
+
"chain": self.inference_data.posterior.chain,
|
|
333
|
+
"draw": self.inference_data.posterior.draw,
|
|
334
|
+
},
|
|
335
|
+
name=f"mod/~/luminosity_{e_min:.1f}_{e_max:.1f}",
|
|
336
|
+
)
|
|
308
337
|
)
|
|
309
338
|
|
|
310
339
|
return value
|
|
@@ -315,10 +344,13 @@ class FitResult:
|
|
|
315
344
|
|
|
316
345
|
Parameters:
|
|
317
346
|
name: The name of the chain.
|
|
347
|
+
parameter_kind: The kind of parameters to keep.
|
|
318
348
|
"""
|
|
319
349
|
|
|
320
350
|
keys_to_drop = [
|
|
321
|
-
key
|
|
351
|
+
key
|
|
352
|
+
for key in self.inference_data.posterior.keys()
|
|
353
|
+
if not key.startswith(parameter_kind)
|
|
322
354
|
]
|
|
323
355
|
|
|
324
356
|
reduced_id = az.extract(
|
|
@@ -403,6 +435,7 @@ class FitResult:
|
|
|
403
435
|
title: str | None = None,
|
|
404
436
|
figsize: tuple[float, float] = (6, 6),
|
|
405
437
|
x_lims: tuple[float, float] | None = None,
|
|
438
|
+
rescale_background: bool = False,
|
|
406
439
|
) -> list[plt.Figure]:
|
|
407
440
|
r"""
|
|
408
441
|
Plot the posterior predictive distribution of the model. It also features a residual plot, defined using the
|
|
@@ -423,6 +456,7 @@ class FitResult:
|
|
|
423
456
|
title: The title of the plot.
|
|
424
457
|
figsize: The size of the figure.
|
|
425
458
|
x_lims: The limits of the x-axis.
|
|
459
|
+
rescale_background: Whether to rescale the background model to the data with backscal ratio.
|
|
426
460
|
|
|
427
461
|
Returns:
|
|
428
462
|
A list of matplotlib figures for each observation in the model.
|
|
@@ -573,10 +607,14 @@ class FitResult:
|
|
|
573
607
|
)
|
|
574
608
|
)
|
|
575
609
|
|
|
610
|
+
rescale_background_factor = (
|
|
611
|
+
obsconf.folded_backratio.data if rescale_background else 1.0
|
|
612
|
+
)
|
|
613
|
+
|
|
576
614
|
model_bkg_plot = _plot_binned_samples_with_error(
|
|
577
615
|
ax[0],
|
|
578
616
|
xbins.value,
|
|
579
|
-
y_samples_bkg.value,
|
|
617
|
+
y_samples_bkg.value * rescale_background_factor,
|
|
580
618
|
color=BACKGROUND_COLOR,
|
|
581
619
|
alpha_envelope=alpha_envelope,
|
|
582
620
|
n_sigmas=n_sigmas,
|
|
@@ -585,9 +623,9 @@ class FitResult:
|
|
|
585
623
|
true_bkg_plot = _plot_poisson_data_with_error(
|
|
586
624
|
ax[0],
|
|
587
625
|
xbins.value,
|
|
588
|
-
y_observed_bkg.value,
|
|
589
|
-
y_observed_bkg_low.value,
|
|
590
|
-
y_observed_bkg_high.value,
|
|
626
|
+
y_observed_bkg.value * rescale_background_factor,
|
|
627
|
+
y_observed_bkg_low.value * rescale_background_factor,
|
|
628
|
+
y_observed_bkg_high.value * rescale_background_factor,
|
|
591
629
|
color=BACKGROUND_DATA_COLOR,
|
|
592
630
|
alpha=0.7,
|
|
593
631
|
)
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import os
|
|
2
|
+
import sparse
|
|
2
3
|
|
|
3
4
|
import numpy as np
|
|
4
5
|
import xarray as xr
|
|
@@ -92,7 +93,7 @@ class Instrument(xr.Dataset):
|
|
|
92
93
|
|
|
93
94
|
else:
|
|
94
95
|
specresp = rmf.matrix.sum(axis=0)
|
|
95
|
-
rmf.sparse_matrix
|
|
96
|
+
rmf.sparse_matrix = sparse.COO( rmf.matrix / specresp )
|
|
96
97
|
|
|
97
98
|
return cls.from_matrix(
|
|
98
99
|
rmf.sparse_matrix, specresp, rmf.energ_lo, rmf.energ_hi, rmf.e_min, rmf.e_max
|
|
@@ -152,11 +152,11 @@ def forward_model_with_multiple_inputs(
|
|
|
152
152
|
transfer_matrix = BCOO.from_scipy_sparse(
|
|
153
153
|
obs_configuration.transfer_matrix.data.to_scipy_sparse().tocsr()
|
|
154
154
|
)
|
|
155
|
+
expected_counts = transfer_matrix @ flux_func(parameters).T
|
|
155
156
|
|
|
156
157
|
else:
|
|
157
158
|
transfer_matrix = np.asarray(obs_configuration.transfer_matrix.data.todense())
|
|
158
|
-
|
|
159
|
-
expected_counts = jnp.matvec(transfer_matrix, flux_func(parameters))
|
|
159
|
+
expected_counts = jnp.matvec(transfer_matrix, flux_func(parameters))
|
|
160
160
|
|
|
161
161
|
# The result is clipped at 1e-6 to avoid 0 round-off and diverging likelihoods
|
|
162
162
|
return jnp.clip(expected_counts, a_min=1e-6)
|
|
@@ -9,6 +9,7 @@ import matplotlib.pyplot as plt
|
|
|
9
9
|
import numpyro
|
|
10
10
|
|
|
11
11
|
from jax import random
|
|
12
|
+
from jax.numpy import concatenate
|
|
12
13
|
from jax.random import PRNGKey
|
|
13
14
|
from numpyro.infer import AIES, ESS, MCMC, NUTS, SVI, Predictive, Trace_ELBO
|
|
14
15
|
from numpyro.infer.autoguide import AutoMultivariateNormal
|
|
@@ -52,9 +53,18 @@ class BayesianModelFitter(BayesianModel, ABC):
|
|
|
52
53
|
)
|
|
53
54
|
|
|
54
55
|
log_likelihood = numpyro.infer.log_likelihood(numpyro_model, posterior_samples)
|
|
56
|
+
if len(log_likelihood.keys()) > 1:
|
|
57
|
+
log_likelihood["full"] = concatenate([ll for _, ll in log_likelihood.items()], axis=1)
|
|
58
|
+
log_likelihood["obs/~/all"] = concatenate(
|
|
59
|
+
[ll for k, ll in log_likelihood.items() if "obs" in k], axis=1
|
|
60
|
+
)
|
|
61
|
+
if self.background_model is not None:
|
|
62
|
+
log_likelihood["bkg/~/all"] = concatenate(
|
|
63
|
+
[ll for k, ll in log_likelihood.items() if "bkg" in k], axis=1
|
|
64
|
+
)
|
|
55
65
|
|
|
56
66
|
seeded_model = numpyro.handlers.substitute(
|
|
57
|
-
numpyro.handlers.seed(numpyro_model, keys[
|
|
67
|
+
numpyro.handlers.seed(numpyro_model, keys[2]),
|
|
58
68
|
substitute_fn=numpyro.infer.init_to_sample,
|
|
59
69
|
)
|
|
60
70
|
|
|
@@ -108,12 +118,10 @@ class BayesianModelFitter(BayesianModel, ABC):
|
|
|
108
118
|
predictive_parameters = []
|
|
109
119
|
|
|
110
120
|
for key, value in self._observation_container.items():
|
|
121
|
+
predictive_parameters.append(f"obs/~/{key}")
|
|
111
122
|
if self.background_model is not None:
|
|
112
|
-
predictive_parameters.append(f"obs/~/{key}")
|
|
113
123
|
predictive_parameters.append(f"bkg/~/{key}")
|
|
114
124
|
# predictive_parameters.append(f"ins/~/{key}")
|
|
115
|
-
else:
|
|
116
|
-
predictive_parameters.append(f"obs/~/{key}")
|
|
117
125
|
# predictive_parameters.append(f"ins/~/{key}")
|
|
118
126
|
|
|
119
127
|
inference_data.posterior_predictive = inference_data.posterior_predictive[
|
|
@@ -247,7 +255,7 @@ class VIFitter(BayesianModelFitter):
|
|
|
247
255
|
|
|
248
256
|
svi = SVI(bayesian_model, guide, optimizer, loss=loss)
|
|
249
257
|
|
|
250
|
-
keys = random.split(random.PRNGKey(rng_key),
|
|
258
|
+
keys = random.split(random.PRNGKey(rng_key), 2)
|
|
251
259
|
svi_result = svi.run(keys[0], num_steps)
|
|
252
260
|
params = svi_result.params
|
|
253
261
|
|
|
@@ -372,9 +372,10 @@ class AdditiveComponent(ModelComponent):
|
|
|
372
372
|
continuum = self.continuum(energy)
|
|
373
373
|
integrated_continuum = self.integrated_continuum(e_low, e_high)
|
|
374
374
|
|
|
375
|
-
return
|
|
376
|
-
continuum * energy**2, jnp.log(energy), axis=-1
|
|
377
|
-
|
|
375
|
+
return (
|
|
376
|
+
jsp.integrate.trapezoid(continuum * energy**2, jnp.log(energy), axis=-1)
|
|
377
|
+
+ integrated_continuum * (e_high + e_low) / 2.0
|
|
378
|
+
)
|
|
378
379
|
|
|
379
380
|
@partial(jax.jit, static_argnums=0, static_argnames="n_points")
|
|
380
381
|
def photon_flux(self, params, e_low, e_high, n_points=2):
|
|
@@ -25,4 +25,5 @@ table_manager = pooch.create(
|
|
|
25
25
|
"example_data/NGC7793_ULX4/MOS2.arf": "sha256:a126ff5a95a5f4bb93ed846944cf411d6e1c448626cb73d347e33324663d8b3f",
|
|
26
26
|
"example_data/NGC7793_ULX4/PNbackground_spectrum.fits": "sha256:55e017e0c19b324245fef049dff2a7a2e49b9a391667ca9c4f667c4f683b1f49",
|
|
27
27
|
},
|
|
28
|
+
retry_if_failed=10,
|
|
28
29
|
)
|
|
@@ -70,6 +70,7 @@ pooch_dataset = pooch.create(
|
|
|
70
70
|
base_url="https://github.com/HEACIT/curated-test-data/raw/main/",
|
|
71
71
|
path=str(data_directory),
|
|
72
72
|
registry=data_hash,
|
|
73
|
+
retry_if_failed=10,
|
|
73
74
|
)
|
|
74
75
|
|
|
75
76
|
for file in data_hash.keys():
|
|
@@ -112,6 +113,15 @@ def get_individual_mcmc_results(obs_model_prior):
|
|
|
112
113
|
return [MCMCFitter(model, prior, obsconf).fit(num_samples=5000) for obsconf in obsconfs]
|
|
113
114
|
|
|
114
115
|
|
|
116
|
+
@pytest.fixture(scope="session")
|
|
117
|
+
def get_failed_mcmc_results(obs_model_prior):
|
|
118
|
+
obsconfs, model, prior = obs_model_prior
|
|
119
|
+
|
|
120
|
+
return [
|
|
121
|
+
MCMCFitter(model, prior, obsconf).fit(num_warmup=10, num_samples=10) for obsconf in obsconfs
|
|
122
|
+
]
|
|
123
|
+
|
|
124
|
+
|
|
115
125
|
@pytest.fixture(scope="session")
|
|
116
126
|
def get_joint_mcmc_result(obs_model_prior):
|
|
117
127
|
obsconfs, model, prior = obs_model_prior
|
|
@@ -93,6 +93,19 @@ def test_fakeits_parallel(obsconfs, model, sharded_parameters):
|
|
|
93
93
|
chex.assert_type(spectra, int)
|
|
94
94
|
|
|
95
95
|
|
|
96
|
+
def test_fakeits_sparsify(obsconfs, model, sharded_parameters):
|
|
97
|
+
obsconf = obsconfs[0]
|
|
98
|
+
spectra = fakeit_for_multiple_parameters(
|
|
99
|
+
obsconf, model, sharded_parameters, apply_stat=False, sparsify_matrix=False
|
|
100
|
+
)
|
|
101
|
+
chex.assert_type(spectra, float)
|
|
102
|
+
|
|
103
|
+
spectra = fakeit_for_multiple_parameters(
|
|
104
|
+
obsconf, model, sharded_parameters, apply_stat=False, sparsify_matrix=True
|
|
105
|
+
)
|
|
106
|
+
chex.assert_type(spectra, float)
|
|
107
|
+
|
|
108
|
+
|
|
96
109
|
def test_fakeits_multiple_observation(obsconfs, model, multidimensional_parameters):
|
|
97
110
|
obsconf = obsconfs[0]
|
|
98
111
|
spectra = fakeit_for_multiple_parameters(
|
|
@@ -3,10 +3,13 @@ import pytest
|
|
|
3
3
|
from jaxspec.fit import BayesianModel
|
|
4
4
|
|
|
5
5
|
|
|
6
|
-
def test_convergence(get_individual_mcmc_results, get_joint_mcmc_result):
|
|
6
|
+
def test_convergence(get_individual_mcmc_results, get_joint_mcmc_result, get_failed_mcmc_results):
|
|
7
7
|
for result in get_individual_mcmc_results + get_joint_mcmc_result:
|
|
8
8
|
assert result.converged
|
|
9
9
|
|
|
10
|
+
for result in get_failed_mcmc_results:
|
|
11
|
+
assert not result.converged
|
|
12
|
+
|
|
10
13
|
|
|
11
14
|
def test_ns(obs_model_prior):
|
|
12
15
|
NSFitter = pytest.importorskip("jaxspec.fit.NSFitter")
|
|
@@ -54,14 +54,18 @@ def test_posterior_photon_flux(get_joint_mcmc_result):
|
|
|
54
54
|
result = get_joint_mcmc_result[0]
|
|
55
55
|
e_min, e_max = 0.7, 1.2
|
|
56
56
|
result.photon_flux(e_min, e_max, register=True)
|
|
57
|
-
assert f"photon_flux_{e_min:.1f}_{e_max:.1f}" in list(
|
|
57
|
+
assert f"mod/~/photon_flux_{e_min:.1f}_{e_max:.1f}" in list(
|
|
58
|
+
result.inference_data.posterior.keys()
|
|
59
|
+
)
|
|
58
60
|
|
|
59
61
|
|
|
60
62
|
def test_posterior_energy_flux(get_joint_mcmc_result):
|
|
61
63
|
result = get_joint_mcmc_result[0]
|
|
62
64
|
e_min, e_max = 0.7, 1.2
|
|
63
65
|
result.energy_flux(e_min, e_max, register=True)
|
|
64
|
-
assert f"energy_flux_{e_min:.1f}_{e_max:.1f}" in list(
|
|
66
|
+
assert f"mod/~/energy_flux_{e_min:.1f}_{e_max:.1f}" in list(
|
|
67
|
+
result.inference_data.posterior.keys()
|
|
68
|
+
)
|
|
65
69
|
|
|
66
70
|
|
|
67
71
|
def test_posterior_luminosity(get_joint_mcmc_result):
|
|
@@ -76,4 +80,6 @@ def test_posterior_luminosity(get_joint_mcmc_result):
|
|
|
76
80
|
|
|
77
81
|
result.luminosity(e_min, e_max, redshift=0.1, register=True)
|
|
78
82
|
|
|
79
|
-
assert f"luminosity_{e_min:.1f}_{e_max:.1f}" in list(
|
|
83
|
+
assert f"mod/~/luminosity_{e_min:.1f}_{e_max:.1f}" in list(
|
|
84
|
+
result.inference_data.posterior.keys()
|
|
85
|
+
)
|
|
@@ -1,8 +1,10 @@
|
|
|
1
1
|
import os
|
|
2
2
|
|
|
3
|
+
import astropy.units as u
|
|
3
4
|
import numpy as np
|
|
4
5
|
import pytest
|
|
5
6
|
|
|
7
|
+
from jaxspec.model.additive import Powerlaw
|
|
6
8
|
from jaxspec.util.online_storage import table_manager
|
|
7
9
|
|
|
8
10
|
xspec = pytest.importorskip("xspec")
|
|
@@ -78,3 +80,40 @@ def test_bins(load_xspec_data, load_jaxspec_data):
|
|
|
78
80
|
assert np.isclose(
|
|
79
81
|
folding.in_energies, xspec_in_energies
|
|
80
82
|
).all(), f"The unfolded channel energy bins are not the same as XSPEC {file_pha}"
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def test_flux_computation():
|
|
86
|
+
xspec.AllData.clear()
|
|
87
|
+
xspec.AllModels.clear()
|
|
88
|
+
|
|
89
|
+
xspec.AllData.dummyrsp(lowE=0.2, highE=1.7, nBins=10_000)
|
|
90
|
+
m = xspec.Model("powerlaw")
|
|
91
|
+
m.powerlaw.PhoIndex = 2.0
|
|
92
|
+
m.powerlaw.norm = 1.0
|
|
93
|
+
|
|
94
|
+
xspec.AllModels.calcFlux("0.5 1.5")
|
|
95
|
+
|
|
96
|
+
phflux_xspec = m.flux[3] # ph/cm^2/s
|
|
97
|
+
eflux_xspec = m.flux[0] # erg/cm^2/s
|
|
98
|
+
|
|
99
|
+
factor = (1 * u.keV).to(u.erg).value
|
|
100
|
+
phflux_jaxspec = Powerlaw().photon_flux(
|
|
101
|
+
{"powerlaw_1_norm": 1.0, "powerlaw_1_alpha": 2.0}, e_low=0.5, e_high=1.5, n_points=10_000
|
|
102
|
+
)
|
|
103
|
+
|
|
104
|
+
eflux_jaxspec = (
|
|
105
|
+
Powerlaw().energy_flux(
|
|
106
|
+
{"powerlaw_1_norm": 1.0, "powerlaw_1_alpha": 2.0},
|
|
107
|
+
e_low=0.5,
|
|
108
|
+
e_high=1.5,
|
|
109
|
+
n_points=10_000,
|
|
110
|
+
)
|
|
111
|
+
* factor
|
|
112
|
+
)
|
|
113
|
+
|
|
114
|
+
assert np.isclose(
|
|
115
|
+
phflux_xspec, phflux_jaxspec
|
|
116
|
+
), f"Mismatch between XSPEC and jaxspec on photon flux, got {phflux_xspec} and {phflux_jaxspec}"
|
|
117
|
+
assert np.isclose(
|
|
118
|
+
eflux_xspec, eflux_jaxspec
|
|
119
|
+
), f"Mismatch between XSPEC and jaxspec on energy flux, got {eflux_xspec} and {eflux_jaxspec}"
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|