jaxspec 0.1.3__tar.gz → 0.1.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {jaxspec-0.1.3 → jaxspec-0.1.4}/PKG-INFO +26 -5
- {jaxspec-0.1.3 → jaxspec-0.1.4}/README.md +21 -0
- {jaxspec-0.1.3 → jaxspec-0.1.4}/pyproject.toml +7 -7
- jaxspec-0.1.4/src/jaxspec/_fit/_build_model.py +140 -0
- {jaxspec-0.1.3 → jaxspec-0.1.4}/src/jaxspec/analysis/results.py +26 -13
- {jaxspec-0.1.3 → jaxspec-0.1.4}/src/jaxspec/data/observation.py +51 -7
- {jaxspec-0.1.3 → jaxspec-0.1.4}/src/jaxspec/data/util.py +1 -1
- {jaxspec-0.1.3 → jaxspec-0.1.4}/src/jaxspec/fit.py +69 -120
- {jaxspec-0.1.3 → jaxspec-0.1.4}/src/jaxspec/model/background.py +95 -87
- jaxspec-0.1.4/src/jaxspec/scripts/__init__.py +0 -0
- {jaxspec-0.1.3 → jaxspec-0.1.4}/LICENSE.md +0 -0
- {jaxspec-0.1.3 → jaxspec-0.1.4}/src/jaxspec/__init__.py +0 -0
- {jaxspec-0.1.3/src/jaxspec/analysis → jaxspec-0.1.4/src/jaxspec/_fit}/__init__.py +0 -0
- {jaxspec-0.1.3/src/jaxspec/model → jaxspec-0.1.4/src/jaxspec/analysis}/__init__.py +0 -0
- {jaxspec-0.1.3 → jaxspec-0.1.4}/src/jaxspec/analysis/_plot.py +0 -0
- {jaxspec-0.1.3 → jaxspec-0.1.4}/src/jaxspec/analysis/compare.py +0 -0
- {jaxspec-0.1.3 → jaxspec-0.1.4}/src/jaxspec/data/__init__.py +0 -0
- {jaxspec-0.1.3 → jaxspec-0.1.4}/src/jaxspec/data/grouping.py +0 -0
- {jaxspec-0.1.3 → jaxspec-0.1.4}/src/jaxspec/data/instrument.py +0 -0
- {jaxspec-0.1.3 → jaxspec-0.1.4}/src/jaxspec/data/obsconf.py +0 -0
- {jaxspec-0.1.3 → jaxspec-0.1.4}/src/jaxspec/data/ogip.py +0 -0
- {jaxspec-0.1.3/src/jaxspec/scripts → jaxspec-0.1.4/src/jaxspec/model}/__init__.py +0 -0
- {jaxspec-0.1.3 → jaxspec-0.1.4}/src/jaxspec/model/abc.py +0 -0
- {jaxspec-0.1.3 → jaxspec-0.1.4}/src/jaxspec/model/additive.py +0 -0
- {jaxspec-0.1.3 → jaxspec-0.1.4}/src/jaxspec/model/list.py +0 -0
- {jaxspec-0.1.3 → jaxspec-0.1.4}/src/jaxspec/model/multiplicative.py +0 -0
- {jaxspec-0.1.3 → jaxspec-0.1.4}/src/jaxspec/scripts/debug.py +0 -0
- {jaxspec-0.1.3 → jaxspec-0.1.4}/src/jaxspec/util/__init__.py +0 -0
- {jaxspec-0.1.3 → jaxspec-0.1.4}/src/jaxspec/util/abundance.py +0 -0
- {jaxspec-0.1.3 → jaxspec-0.1.4}/src/jaxspec/util/integrate.py +0 -0
- {jaxspec-0.1.3 → jaxspec-0.1.4}/src/jaxspec/util/online_storage.py +0 -0
- {jaxspec-0.1.3 → jaxspec-0.1.4}/src/jaxspec/util/typing.py +0 -0
{jaxspec-0.1.3 → jaxspec-0.1.4}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: jaxspec
-Version: 0.1.3
+Version: 0.1.4
 Summary: jaxspec is a bayesian spectral fitting library for X-ray astronomy.
 Home-page: https://github.com/renecotyfanboy/jaxspec
 License: MIT
@@ -11,7 +11,7 @@ Classifier: License :: OSI Approved :: MIT License
 Classifier: Programming Language :: Python :: 3
 Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
-Requires-Dist: arviz (>=0.17.1,<0.
+Requires-Dist: arviz (>=0.17.1,<0.21.0)
 Requires-Dist: astropy (>=6.0.0,<7.0.0)
 Requires-Dist: chainconsumer (>=1.1.2,<2.0.0)
 Requires-Dist: cmasher (>=1.6.3,<2.0.0)
@@ -23,17 +23,17 @@ Requires-Dist: jaxlib (>=0.4.30,<0.5.0)
 Requires-Dist: jaxns (<2.6)
 Requires-Dist: jaxopt (>=0.8.1,<0.9.0)
 Requires-Dist: matplotlib (>=3.8.0,<4.0.0)
-Requires-Dist: mendeleev (>=0.15,<0.
+Requires-Dist: mendeleev (>=0.15,<0.19)
 Requires-Dist: networkx (>=3.1,<4.0)
 Requires-Dist: numpy (<2.0.0)
 Requires-Dist: numpyro (>=0.15.3,<0.16.0)
-Requires-Dist: optimistix (>=0.0.7,<0.0.
+Requires-Dist: optimistix (>=0.0.7,<0.0.10)
 Requires-Dist: pandas (>=2.2.0,<3.0.0)
 Requires-Dist: pooch (>=1.8.2,<2.0.0)
 Requires-Dist: pyzmq (<27)
 Requires-Dist: scipy (<1.15)
 Requires-Dist: seaborn (>=0.13.1,<0.14.0)
-Requires-Dist: simpleeval (>=0.9.13,<
+Requires-Dist: simpleeval (>=0.9.13,<1.1.0)
 Requires-Dist: sparse (>=0.15.1,<0.16.0)
 Requires-Dist: tinygp (>=0.3.0,<0.4.0)
 Requires-Dist: watermark (>=2.4.3,<3.0.0)
@@ -78,3 +78,24 @@ Once the environment is set up, you can install jaxspec directly from pypi
 pip install jaxspec --upgrade
 ```
 
+## Citation
+
+If you use `jaxspec` in your research, please consider citing the following article
+
+```
+@ARTICLE{2024A&A...690A.317D,
+       author = {{Dupourqu{\'e}}, S. and {Barret}, D. and {Diez}, C.~M. and {Guillot}, S. and {Quintin}, E.},
+        title = "{jaxspec: A fast and robust Python library for X-ray spectral fitting}",
+      journal = {\aap},
+     keywords = {methods: data analysis, methods: statistical, X-rays: general},
+         year = 2024,
+        month = oct,
+       volume = {690},
+          eid = {A317},
+        pages = {A317},
+          doi = {10.1051/0004-6361/202451736},
+       adsurl = {https://ui.adsabs.harvard.edu/abs/2024A&A...690A.317D},
+      adsnote = {Provided by the SAO/NASA Astrophysics Data System}
+}
+```
+
{jaxspec-0.1.3 → jaxspec-0.1.4}/README.md

@@ -35,3 +35,24 @@ Once the environment is set up, you can install jaxspec directly from pypi
 ```
 pip install jaxspec --upgrade
 ```
+
+## Citation
+
+If you use `jaxspec` in your research, please consider citing the following article
+
+```
+@ARTICLE{2024A&A...690A.317D,
+       author = {{Dupourqu{\'e}}, S. and {Barret}, D. and {Diez}, C.~M. and {Guillot}, S. and {Quintin}, E.},
+        title = "{jaxspec: A fast and robust Python library for X-ray spectral fitting}",
+      journal = {\aap},
+     keywords = {methods: data analysis, methods: statistical, X-rays: general},
+         year = 2024,
+        month = oct,
+       volume = {690},
+          eid = {A317},
+        pages = {A317},
+          doi = {10.1051/0004-6361/202451736},
+       adsurl = {https://ui.adsabs.harvard.edu/abs/2024A&A...690A.317D},
+      adsnote = {Provided by the SAO/NASA Astrophysics Data System}
+}
+```
{jaxspec-0.1.3 → jaxspec-0.1.4}/pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "jaxspec"
-version = "0.1.3"
+version = "0.1.4"
 description = "jaxspec is a bayesian spectral fitting library for X-ray astronomy."
 authors = ["sdupourque <sdupourque@irap.omp.eu>"]
 license = "MIT"
@@ -20,18 +20,18 @@ numpyro = "^0.15.3"
 dm-haiku = "^0.0.12"
 networkx = "^3.1"
 matplotlib = "^3.8.0"
-arviz = ">=0.17.1,<0.
+arviz = ">=0.17.1,<0.21.0"
 chainconsumer = "^1.1.2"
-simpleeval = "
+simpleeval = ">=0.9.13,<1.1.0"
 cmasher = "^1.6.3"
 gpjax = "^0.8.0"
 jaxopt = "^0.8.1"
 tinygp = "^0.3.0"
 seaborn = "^0.13.1"
 sparse = "^0.15.1"
-optimistix = "
+optimistix = ">=0.0.7,<0.0.10"
 scipy = "<1.15"
-mendeleev = ">=0.15,<0.
+mendeleev = ">=0.15,<0.19"
 pyzmq = "<27"
 jaxns = "<2.6"
 pooch = "^1.8.2"
@@ -57,8 +57,8 @@ testbook = "^0.4.2"
 
 
 [tool.poetry.group.dev.dependencies]
-pre-commit = "
-ruff = ">=0.2.1,<0.
+pre-commit = ">=3.5,<5.0"
+ruff = ">=0.2.1,<0.8.0"
 jupyterlab = "^4.0.7"
 notebook = "^7.0.6"
 ipywidgets = "^8.1.1"
jaxspec-0.1.4/src/jaxspec/_fit/_build_model.py (new file)

@@ -0,0 +1,140 @@
+import jax
+import numpyro
+import haiku as hk
+import numpy as np
+import jax.numpy as jnp
+from typing import Callable
+from jax.experimental.sparse import BCOO
+from typing import TYPE_CHECKING
+from numpyro.distributions import Poisson
+from jax.typing import ArrayLike
+from numpyro.distributions import Distribution
+
+
+if TYPE_CHECKING:
+    from ..model.abc import SpectralModel
+    from ..data import ObsConfiguration
+    from ..util.typing import PriorDictModel, PriorDictType
+
+
+
+class CountForwardModel(hk.Module):
+    """
+    A haiku module which allows to build the function that simulates the measured counts
+    """
+
+    # TODO: It has no point of being a haiku module, it should be a simple function
+
+    def __init__(self, model: 'SpectralModel', folding: 'ObsConfiguration', sparse=False):
+        super().__init__()
+        self.model = model
+        self.energies = jnp.asarray(folding.in_energies)
+
+        if (
+            sparse
+        ):  # folding.transfer_matrix.data.density > 0.015 is a good criterion to consider sparsify
+            self.transfer_matrix = BCOO.from_scipy_sparse(
+                folding.transfer_matrix.data.to_scipy_sparse().tocsr()
+            )
+
+        else:
+            self.transfer_matrix = jnp.asarray(folding.transfer_matrix.data.todense())
+
+    def __call__(self, parameters):
+        """
+        Compute the count functions for a given observation.
+        """
+
+        expected_counts = self.transfer_matrix @ self.model.photon_flux(parameters, *self.energies)
+
+        return jnp.clip(expected_counts, a_min=1e-6)
+
+
+def forward_model(
+    model: 'SpectralModel',
+    parameters,
+    obs_configuration: 'ObsConfiguration',
+    sparse=False,
+):
+
+    energies = np.asarray(obs_configuration.in_energies)
+
+    if sparse:
+        # folding.transfer_matrix.data.density > 0.015 is a good criterion to consider sparsify
+        transfer_matrix = BCOO.from_scipy_sparse(
+            obs_configuration.transfer_matrix.data.to_scipy_sparse().tocsr()
+        )
+
+    else:
+        transfer_matrix = np.asarray(obs_configuration.transfer_matrix.data.todense())
+
+    expected_counts = transfer_matrix @ model.photon_flux(parameters, *energies)
+
+    # The result is clipped at 1e-6 to avoid 0 round-off and diverging likelihoods
+    return jnp.clip(expected_counts, a_min=1e-6)
+
+
+def build_numpyro_model_for_single_obs(
+    obs,
+    model,
+    background_model,
+    name: str = "",
+    sparse: bool = False,
+) -> Callable:
+    """
+    Build a numpyro model for a given observation and spectral model.
+    """
+
+    def numpyro_model(prior_params, observed=True):
+
+        # Return the expected countrate for a set of parameters
+        obs_model = jax.jit(lambda par: forward_model(model, par, obs, sparse=sparse))
+        countrate = obs_model(prior_params)
+
+        # Handle the background model
+        if (getattr(obs, "folded_background", None) is not None) and (background_model is not None):
+            bkg_countrate = background_model.numpyro_model(
+                obs, model, name="bkg_" + name, observed=observed
+            )
+
+        elif (getattr(obs, "folded_background", None) is None) and (background_model is not None):
+            raise ValueError(
+                "Trying to fit a background model but no background is linked to this observation"
+            )
+
+        else:
+            bkg_countrate = 0.0
+
+
+        # Register the observed value
+        # This is the case where we fit a model to a TOTAL spectrum as defined in OGIP standard
+        with numpyro.plate("obs_plate_" + name, len(obs.folded_counts)):
+            numpyro.sample(
+                "obs_" + name,
+                Poisson(countrate + bkg_countrate / obs.folded_backratio.data),
+                obs=obs.folded_counts.data if observed else None,
+            )
+
+    return numpyro_model
+
+
+def build_prior(prior: 'PriorDictType', expand_shape: tuple = (), prefix=""):
+    """
+    Transform a dictionary of prior distributions into a dictionary of parameters sampled from the prior.
+    Must be used within a numpyro model.
+    """
+    parameters = dict(hk.data_structures.to_haiku_dict(prior))
+
+    for i, (m, n, sample) in enumerate(hk.data_structures.traverse(prior)):
+        if isinstance(sample, Distribution):
+            parameters[m][n] = jnp.ones(expand_shape) * numpyro.sample(f"{prefix}{m}_{n}", sample)
+
+        elif isinstance(sample, ArrayLike):
+            parameters[m][n] = jnp.ones(expand_shape) * sample
+
+        else:
+            raise ValueError(
+                f"Invalid prior type {type(sample)} for parameter {prefix}{m}_{n} : {sample}"
+            )
+
+    return parameters
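The `forward_model` helper above folds the model photon flux through the instrument transfer matrix, either as a dense array or as a JAX sparse matrix. Below is a standalone sketch of that switch, using a random scipy matrix as a stand-in for the response (not jaxspec data):

```python
import jax.numpy as jnp
import scipy.sparse as sp
from jax.experimental.sparse import BCOO

# Stand-in for an instrument transfer matrix; jaxspec stores it as a pydata/sparse
# array, a scipy CSR matrix is used here only for illustration.
transfer = sp.random(500, 1000, density=0.01, format="csr")
use_sparse = True  # corresponds to the `sparse` flag in forward_model

if use_sparse:
    folded = BCOO.from_scipy_sparse(transfer)      # keep the matrix sparse on device
else:
    folded = jnp.asarray(transfer.toarray())       # densify small matrices

photon_flux = jnp.ones(1000)            # placeholder for model.photon_flux(parameters, *energies)
expected_counts = folded @ photon_flux  # same matrix-vector product as in forward_model
```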
{jaxspec-0.1.3 → jaxspec-0.1.4}/src/jaxspec/analysis/results.py

@@ -4,6 +4,7 @@ from collections.abc import Mapping
 from typing import TYPE_CHECKING, Any, Literal, TypeVar
 
 import arviz as az
+import astropy.cosmology.units as cu
 import astropy.units as u
 import jax
 import jax.numpy as jnp
@@ -287,7 +288,8 @@ class FitResult:
         self,
         e_min: float,
         e_max: float,
-        redshift: float | ArrayLike =
+        redshift: float | ArrayLike = None,
+        distance: float | ArrayLike = None,
         observer_frame: bool = True,
         cosmology: Cosmology = Planck18,
         unit: Unit = u.erg / u.s,
@@ -310,6 +312,17 @@ class FitResult:
         if not observer_frame:
             raise NotImplementedError()
 
+        if redshift is None and distance is None:
+            raise ValueError("Either redshift or distance must be specified.")
+
+        if distance is not None:
+            if redshift is not None:
+                raise ValueError("Redshift must be None as a distance is specified.")
+            else:
+                redshift = distance.to(
+                    cu.redshift, cu.redshift_distance(cosmology, kind="luminosity")
+                ).value
+
         @jax.jit
         @jnp.vectorize
         def vectorized_flux(*pars):
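The new `distance` argument is mapped to a redshift through `astropy.cosmology.units`. A minimal sketch of that conversion outside of jaxspec (the 100 Mpc value is arbitrary):

```python
import astropy.units as u
import astropy.cosmology.units as cu
from astropy.cosmology import Planck18

# Convert a luminosity distance to the equivalent redshift under Planck18,
# mirroring the conversion added above.
d_lum = 100 * u.Mpc
z = d_lum.to(cu.redshift, cu.redshift_distance(Planck18, kind="luminosity"))
print(z)  # dimensionless redshift
```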
@@ -489,13 +502,13 @@ class FitResult:
 
         match y_type:
             case "counts":
-                y_units = u.
+                y_units = u.ct
             case "countrate":
-                y_units = u.
+                y_units = u.ct / u.s
             case "photon_flux":
-                y_units = u.
+                y_units = u.ct / u.cm**2 / u.s
             case "photon_flux_density":
-                y_units = u.
+                y_units = u.ct / u.cm**2 / u.s / x_unit
             case _:
                 raise ValueError(
                     f"Unknown y_type: {y_type}. Must be 'counts', 'countrate', 'photon_flux' or 'photon_flux_density'"
@@ -566,16 +579,16 @@ class FitResult:
             case "photon_flux_density":
                 denominator = (xbins[1] - xbins[0]) * integrated_arf * exposure
 
-        y_samples = (count * u.
-        y_observed = (obsconf.folded_counts.data * u.
+        y_samples = (count * u.ct / denominator).to(y_units)
+        y_observed = (obsconf.folded_counts.data * u.ct / denominator).to(y_units)
         y_observed_low = (
             nbinom.ppf(percentile[0] / 100, obsconf.folded_counts.data, 0.5)
-            * u.
+            * u.ct
             / denominator
         ).to(y_units)
         y_observed_high = (
             nbinom.ppf(percentile[1] / 100, obsconf.folded_counts.data, 0.5)
-            * u.
+            * u.ct
             / denominator
         ).to(y_units)
 
@@ -611,18 +624,18 @@ class FitResult:
         if self.background_model is not None:
             # We plot the background only if it is included in the fit, i.e. by subtracting
             ratio = obsconf.folded_backratio.data
-            y_samples_bkg = (bkg_count * u.
+            y_samples_bkg = (bkg_count * u.ct / (denominator * ratio)).to(y_units)
             y_observed_bkg = (
-                obsconf.folded_background.data * u.
+                obsconf.folded_background.data * u.ct / (denominator * ratio)
             ).to(y_units)
             y_observed_bkg_low = (
                 nbinom.ppf(percentile[0] / 100, obsconf.folded_background.data, 0.5)
-                * u.
+                * u.ct
                 / (denominator * ratio)
             ).to(y_units)
             y_observed_bkg_high = (
                 nbinom.ppf(percentile[1] / 100, obsconf.folded_background.data, 0.5)
-                * u.
+                * u.ct
                / (denominator * ratio)
            ).to(y_units)
 
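The `y_units` cases above track count units with astropy so that every plotted quantity carries a consistent unit. A small sketch of the same bookkeeping with made-up numbers: counts divided by bin width, effective area, and exposure give a photon flux density.

```python
import astropy.units as u

# Illustrative values only, not taken from a real observation.
counts = 250 * u.ct
bin_width = 0.1 * u.keV
effective_area = 450 * u.cm**2
exposure = 10_000 * u.s

denominator = bin_width * effective_area * exposure
photon_flux_density = (counts / denominator).to(u.ct / u.cm**2 / u.s / u.keV)
print(photon_flux_density)
```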
{jaxspec-0.1.3 → jaxspec-0.1.4}/src/jaxspec/data/observation.py

@@ -1,5 +1,6 @@
 import numpy as np
 import xarray as xr
+
 from .ogip import DataPHA
 
 
@@ -23,7 +24,16 @@ class Observation(xr.Dataset):
     folded_background: xr.DataArray
     """The background counts, after grouping"""
 
-    __slots__ = (
+    __slots__ = (
+        "grouping",
+        "channel",
+        "quality",
+        "exposure",
+        "background",
+        "folded_background",
+        "counts",
+        "folded_counts",
+    )
 
     _default_attributes = {"description": "X-ray observation dataset"}
 
@@ -46,7 +56,11 @@ class Observation(xr.Dataset):
             background = np.zeros_like(counts, dtype=np.int64)
 
         data_dict = {
-            "counts": (
+            "counts": (
+                ["instrument_channel"],
+                np.asarray(counts, dtype=np.int64),
+                {"description": "Counts", "unit": "photons"},
+            ),
             "folded_counts": (
                 ["folded_channel"],
                 np.asarray(np.ma.filled(grouping @ counts), dtype=np.int64),
@@ -57,7 +71,11 @@ class Observation(xr.Dataset):
                 grouping,
                 {"description": "Grouping matrix."},
             ),
-            "quality": (
+            "quality": (
+                ["instrument_channel"],
+                np.asarray(quality, dtype=np.int64),
+                {"description": "Quality flag."},
+            ),
             "exposure": ([], float(exposure), {"description": "Total exposure", "unit": "s"}),
             "backratio": (
                 ["instrument_channel"],
@@ -84,20 +102,29 @@ class Observation(xr.Dataset):
         return cls(
             data_dict,
             coords={
-                "channel": (
+                "channel": (
+                    ["instrument_channel"],
+                    np.asarray(channel, dtype=np.int64),
+                    {"description": "Channel number"},
+                ),
                 "grouped_channel": (
                     ["folded_channel"],
                     np.arange(len(grouping @ counts), dtype=np.int64),
                     {"description": "Channel number"},
                 ),
             },
-            attrs=cls._default_attributes
+            attrs=cls._default_attributes
+            if attributes is None
+            else attributes | cls._default_attributes,
         )
 
     @classmethod
     def from_ogip_container(cls, pha: DataPHA, bkg: DataPHA | None = None, **metadata):
         if bkg is not None:
-            backratio = np.nan_to_num(
+            backratio = np.nan_to_num(
+                (pha.backscal * pha.exposure * pha.areascal)
+                / (bkg.backscal * bkg.exposure * bkg.areascal)
+            )
         else:
             backratio = np.ones_like(pha.counts)
 
@@ -114,6 +141,14 @@ class Observation(xr.Dataset):
 
     @classmethod
     def from_pha_file(cls, pha_path: str, bkg_path: str | None = None, **metadata):
+        """
+        Build an observation from a PHA file
+
+        Parameters:
+            pha_path : Path to the PHA file
+            bkg_path : Path to the background file
+            metadata : Additional metadata to add to the observation
+        """
         from .util import data_path_finder
 
         arf_path, rmf_path, bkg_path_default = data_path_finder(pha_path)
@@ -155,7 +190,16 @@ class Observation(xr.Dataset):
 
         fig = plt.figure(figsize=(6, 6))
         gs = fig.add_gridspec(
-            2,
+            2,
+            2,
+            width_ratios=(4, 1),
+            height_ratios=(1, 4),
+            left=0.1,
+            right=0.9,
+            bottom=0.1,
+            top=0.9,
+            wspace=0.05,
+            hspace=0.05,
         )
         ax = fig.add_subplot(gs[1, 0])
         ax_histx = fig.add_subplot(gs[0, 0], sharex=ax)
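The new docstring on `from_pha_file` documents the main entry point for loading spectra. A minimal usage sketch with hypothetical file names (the ARF/RMF/background paths are resolved by `data_path_finder` when not given explicitly):

```python
from jaxspec.data import Observation

# Hypothetical paths; only the PHA file is mandatory, the rest is auto-discovered
# from the PHA header when possible.
obs = Observation.from_pha_file("source_spectrum.pha", bkg_path="source_background.pha")
print(obs)  # xarray Dataset holding counts, folded_counts, grouping, exposure, ...
```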
{jaxspec-0.1.3 → jaxspec-0.1.4}/src/jaxspec/data/util.py

@@ -11,7 +11,7 @@ from astropy.io import fits
 from numpy.typing import ArrayLike
 from numpyro import handlers
 
-from ..
+from .._fit._build_model import CountForwardModel
 from ..model.abc import SpectralModel
 from ..util.online_storage import table_manager
 from . import Instrument, ObsConfiguration, Observation
{jaxspec-0.1.3 → jaxspec-0.1.4}/src/jaxspec/fit.py

@@ -7,7 +7,6 @@ from functools import cached_property
 from typing import Literal
 
 import arviz as az
-import haiku as hk
 import jax
 import jax.numpy as jnp
 import matplotlib.pyplot as plt
@@ -15,17 +14,15 @@ import numpy as np
 import numpyro
 
 from jax import random
-from jax.experimental.sparse import BCOO
 from jax.random import PRNGKey
-from jax.tree_util import tree_map
-from jax.typing import ArrayLike
 from numpyro.contrib.nested_sampling import NestedSampler
-from numpyro.distributions import
+from numpyro.distributions import Poisson, TransformedDistribution
 from numpyro.infer import AIES, ESS, MCMC, NUTS, Predictive
 from numpyro.infer.inspect import get_model_relations
 from numpyro.infer.reparam import TransformReparam
 from numpyro.infer.util import log_density
 
+from ._fit._build_model import build_prior, forward_model
 from .analysis._plot import _plot_poisson_data_with_error
 from .analysis.results import FitResult
 from .data import ObsConfiguration
@@ -34,101 +31,6 @@ from .model.background import BackgroundModel
 from .util.typing import PriorDictModel, PriorDictType
 
 
-def build_prior(prior: PriorDictType, expand_shape: tuple = (), prefix=""):
-    """
-    Transform a dictionary of prior distributions into a dictionary of parameters sampled from the prior.
-    Must be used within a numpyro model.
-    """
-    parameters = dict(hk.data_structures.to_haiku_dict(prior))
-
-    for i, (m, n, sample) in enumerate(hk.data_structures.traverse(prior)):
-        if isinstance(sample, Distribution):
-            parameters[m][n] = jnp.ones(expand_shape) * numpyro.sample(f"{prefix}{m}_{n}", sample)
-
-        elif isinstance(sample, ArrayLike):
-            parameters[m][n] = jnp.ones(expand_shape) * sample
-
-        else:
-            raise ValueError(
-                f"Invalid prior type {type(sample)} for parameter {prefix}{m}_{n} : {sample}"
-            )
-
-    return parameters
-
-
-def build_numpyro_model_for_single_obs(
-    obs: ObsConfiguration,
-    model: SpectralModel,
-    background_model: BackgroundModel,
-    name: str = "",
-    sparse: bool = False,
-) -> Callable:
-    """
-    Build a numpyro model for a given observation and spectral model.
-    """
-
-    def numpyro_model(prior_params, observed=True):
-        # prior_params = build_prior(prior_distributions, name=name)
-        transformed_model = hk.without_apply_rng(
-            hk.transform(lambda par: CountForwardModel(model, obs, sparse=sparse)(par))
-        )
-
-        if (getattr(obs, "folded_background", None) is not None) and (background_model is not None):
-            bkg_countrate = background_model.numpyro_model(
-                obs, model, name="bkg_" + name, observed=observed
-            )
-        elif (getattr(obs, "folded_background", None) is None) and (background_model is not None):
-            raise ValueError(
-                "Trying to fit a background model but no background is linked to this observation"
-            )
-
-        else:
-            bkg_countrate = 0.0
-
-        obs_model = jax.jit(lambda p: transformed_model.apply(None, p))
-        countrate = obs_model(prior_params)
-
-        # This is the case where we fit a model to a TOTAL spectrum as defined in OGIP standard
-        with numpyro.plate("obs_plate_" + name, len(obs.folded_counts)):
-            numpyro.sample(
-                "obs_" + name,
-                Poisson(countrate + bkg_countrate / obs.folded_backratio.data),
-                obs=obs.folded_counts.data if observed else None,
-            )
-
-    return numpyro_model
-
-
-class CountForwardModel(hk.Module):
-    """
-    A haiku module which allows to build the function that simulates the measured counts
-    """
-
-    def __init__(self, model: SpectralModel, folding: ObsConfiguration, sparse=False):
-        super().__init__()
-        self.model = model
-        self.energies = jnp.asarray(folding.in_energies)
-
-        if (
-            sparse
-        ):  # folding.transfer_matrix.data.density > 0.015 is a good criterion to consider sparsify
-            self.transfer_matrix = BCOO.from_scipy_sparse(
-                folding.transfer_matrix.data.to_scipy_sparse().tocsr()
-            )
-
-        else:
-            self.transfer_matrix = jnp.asarray(folding.transfer_matrix.data.todense())
-
-    def __call__(self, parameters):
-        """
-        Compute the count functions for a given observation.
-        """
-
-        expected_counts = self.transfer_matrix @ self.model.photon_flux(parameters, *self.energies)
-
-        return jnp.clip(expected_counts, a_min=1e-6)
-
-
 class BayesianModel:
     """
     Base class for a Bayesian model. This class contains the necessary methods to build a model, sample from the prior
@@ -157,7 +59,6 @@ class BayesianModel:
         self.model = model
         self._observations = observations
         self.background_model = background_model
-        self.pars = tree_map(lambda x: jnp.float64(x), self.model.params)
         self.sparse = sparsify_matrix
 
         if not callable(prior_distributions):
@@ -197,22 +98,50 @@ class BayesianModel:
         Build the numpyro model using the observed data, the prior distributions and the spectral model.
         """
 
-        def
+        def numpyro_model(observed=True):
+            # Instantiate and register the parameters of the spectral model and the background
             prior_params = self.prior_distributions_func()
 
             # Iterate over all the observations in our container and build a single numpyro model for each observation
-            for i, (
+            for i, (name, observation) in enumerate(self.observation_container.items()):
+                # Check that we can indeed fit a background
+                if (getattr(observation, "folded_background", None) is not None) and (
+                    self.background_model is not None
+                ):
+                    # This call should register the parameter and observation of our background model
+                    bkg_countrate = self.background_model.numpyro_model(
+                        observation, name=name, observed=observed
+                    )
+
+                elif (getattr(observation, "folded_background", None) is None) and (
+                    self.background_model is not None
+                ):
+                    raise ValueError(
+                        "Trying to fit a background model but no background is linked to this observation"
+                    )
+
+                else:
+                    bkg_countrate = 0.0
+
                 # We expect that prior_params contains an array of parameters for each observation
                 # They can be identical or different for each observation
-                params =
+                params = jax.tree.map(lambda x: x[i], prior_params)
 
-
-
+                # Forward model the observation and get the associated countrate
+                obs_model = jax.jit(
+                    lambda par: forward_model(self.model, par, observation, sparse=self.sparse)
                 )
+                obs_countrate = obs_model(params)
 
-
+                # Register the observation as an observed site
+                with numpyro.plate("obs_plate_" + name, len(observation.folded_counts)):
+                    numpyro.sample(
+                        "obs_" + name,
+                        Poisson(obs_countrate + bkg_countrate / observation.folded_backratio.data),
+                        obs=observation.folded_counts.data if observed else None,
+                    )
 
-        return
+        return numpyro_model
 
     @cached_property
     def transformed_numpyro_model(self) -> Callable:
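The rewritten `numpyro_model` slices one parameter set per observation with `jax.tree.map(lambda x: x[i], prior_params)`. A self-contained illustration of that slicing (parameter names and values below are made up):

```python
import jax
import jax.numpy as jnp

# Each leaf of the prior dictionary carries one value per observation along its
# leading axis; jax.tree.map picks out the i-th slice from every leaf at once.
prior_params = {
    "powerlaw_1": {
        "alpha": jnp.array([1.7, 1.9]),   # one value per observation
        "norm": jnp.array([1e-4, 2e-4]),
    }
}

params_obs_0 = jax.tree.map(lambda x: x[0], prior_params)
# -> {'powerlaw_1': {'alpha': 1.7, 'norm': 1e-4}} for the first observation
```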
@@ -352,7 +281,9 @@ class BayesianModel:
         return fakeit(key, parameters)
 
     def prior_predictive_coverage(
-        self,
+        self,
+        key: PRNGKey = PRNGKey(0),
+        num_samples: int = 1000,
     ):
         """
         Check if the prior distribution include the observed data.
@@ -363,24 +294,36 @@ class BayesianModel:
 
         for key, value in self.observation_container.items():
             fig, axs = plt.subplots(
-                nrows=2, ncols=1, sharex=True, figsize=(
+                nrows=2, ncols=1, sharex=True, figsize=(5, 6), height_ratios=[3, 1]
             )
 
             _plot_poisson_data_with_error(
                 axs[0],
                 value.out_energies,
                 value.folded_counts.values,
-                percentiles=
+                percentiles=(16, 84),
             )
 
-
-
-
-
-
-
-
-
+            for i, (envelop_percentiles, color, alpha) in enumerate(
+                zip(
+                    [(16, 86), (2.5, 97.5), (0.15, 99.85)],
+                    ["#03045e", "#0077b6", "#00b4d8"],
+                    [0.5, 0.4, 0.3],
+                )
+            ):
+                lower, upper = np.percentile(
+                    posterior_observations["obs_" + key], envelop_percentiles, axis=0
+                )
+
+                axs[0].stairs(
+                    upper,
+                    edges=[*list(value.out_energies[0]), value.out_energies[1][-1]],
+                    baseline=lower,
+                    alpha=alpha,
+                    fill=True,
+                    color=color,
+                    label=rf"${1+i}\sigma$",
+                )
 
             # rank = np.vstack((posterior_observations["obs_" + key], value.folded_counts.values)).argsort(axis=0)[-1] / (num_samples) * 100
             counts = posterior_observations["obs_" + key]
@@ -408,7 +351,9 @@ class BayesianModel:
             axs[1].set_ylim(0, 100)
             axs[0].set_xlim(value.out_energies.min(), value.out_energies.max())
             axs[0].loglog()
+            axs[0].legend(loc="upper right")
             plt.suptitle(f"Prior Predictive coverage for {key}")
+            plt.tight_layout()
             plt.show()
 
 
@@ -513,7 +458,11 @@ class BayesianModelFitter(BayesianModel, ABC):
             predictive_parameters
         ]
 
-        parameters = [
+        parameters = [
+            x
+            for x in inference_data.posterior.keys()
+            if not x.endswith("_base") or x.startswith("_")
+        ]
         inference_data.posterior = inference_data.posterior[parameters]
         inference_data.prior = inference_data.prior[parameters]
 
{jaxspec-0.1.3 → jaxspec-0.1.4}/src/jaxspec/model/background.py

@@ -1,22 +1,28 @@
 from abc import ABC, abstractmethod
 
+import jax
 import jax.numpy as jnp
 import numpyro
 import numpyro.distributions as dist
 
 from jax.scipy.integrate import trapezoid
+from numpyro.distributions import Poisson
 from tinygp import GaussianProcess, kernels
 
+from .._fit._build_model import build_prior, forward_model
+from ..util.typing import PriorDictModel
+from .abc import SpectralModel
+
 
 class BackgroundModel(ABC):
     """
-
+    Handles the background modelling in our spectra. This is handled in a separate class for now
     since backgrounds can be phenomenological models fitted directly on the folded spectrum. This is not the case for
     the source model, which is fitted on the unfolded spectrum. This might be changed later.
     """
 
     @abstractmethod
-    def numpyro_model(self,
+    def numpyro_model(self, observation, name: str = "", observed=True):
         """
         Build the model for the background.
         """
@@ -25,7 +31,7 @@ class BackgroundModel(ABC):
 
 class SubtractedBackground(BackgroundModel):
     """
-
+    Define a model where the observed background is simply subtracted from the observed.
 
     !!! danger
 
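The abstract `numpyro_model(self, observation, name, observed)` signature above is the whole contract for a background model: sample its parameters, optionally condition on `observation.folded_background`, and return the background counts per folded channel. Below is a hypothetical flat-background subclass, purely as a sketch of that interface and not part of jaxspec:

```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

from jaxspec.model.background import BackgroundModel


class FlatBackground(BackgroundModel):
    """Hypothetical example: a single Poisson level shared by every channel."""

    def numpyro_model(self, observation, name: str = "", observed=True):
        counts = observation.folded_background.data
        level = numpyro.sample(f"_bkg_{name}_level", dist.HalfNormal(10.0))

        # Condition the shared level on the measured background counts
        with numpyro.plate(f"bkg_plate_{name}", len(counts)):
            numpyro.sample(
                f"bkg_{name}",
                dist.Poisson(level),
                obs=counts if observed else None,
            )

        # Return the expected background counts per folded channel
        return level * jnp.ones_like(counts)
```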
@@ -35,93 +41,40 @@ class SubtractedBackground(BackgroundModel):
 
     """
 
-    def numpyro_model(self,
-        _, observed_counts =
-        numpyro.deterministic(f"{name}", observed_counts)
+    def numpyro_model(self, observation, name: str = "", observed=True):
+        _, observed_counts = observation.out_energies, observation.folded_background.data
+        numpyro.deterministic(f"bkg_{name}", observed_counts)
 
-        return
+        return observed_counts
 
 
 class BackgroundWithError(BackgroundModel):
     """
-
+    Define a model where the observed background is subtracted from the observed accounting for its intrinsic spread. It
     fits a countrate for each background bin assuming a Poisson distribution.
-
-    !!! warning
-        This is the same as [`ConjugateBackground`][jaxspec.model.background.ConjugateBackground]
-        but slower since it performs the fit using MCMC instead of analytical solution.
     """
 
-    def numpyro_model(self, obs,
+    def numpyro_model(self, obs, name: str = "", observed=True):
+        # We can't use the build_prior_function method here because the parameter size varies
+        # with the current observation. It must be instantiated in place.
         # Gamma in numpyro is parameterized by concentration and rate (alpha/beta)
+
         _, observed_counts = obs.out_energies, obs.folded_background.data
         alpha = observed_counts + 1
         beta = 1
-        countrate = numpyro.sample(f"{name}
+        countrate = numpyro.sample(f"_bkg_{name}_countrate", dist.Gamma(alpha, rate=beta))
 
-        with numpyro.plate(f"{name}_plate", len(observed_counts)):
+        with numpyro.plate(f"bkg_{name}_plate", len(observed_counts)):
             numpyro.sample(
-                f"{name}", dist.Poisson(countrate), obs=observed_counts if observed else None
+                f"bkg_{name}", dist.Poisson(countrate), obs=observed_counts if observed else None
             )
 
         return countrate
 
 
-'''
-# TODO: Implement this class and sample it with Gibbs Sampling
-
-class ConjugateBackground(BackgroundModel):
-    r"""
-    This class fit an expected rate $\\lambda$ in each bin of the background spectrum. Assuming a Gamma prior
-    distribution, we can analytically derive the posterior as a Negative binomial distribution.
-
-    $$ p(\\lambda_{\text{Bkg}}) \\sim \\Gamma \\left( \alpha, \beta \right) \\implies
-    p\\left(\\lambda_{\text{Bkg}} | \text{Counts}_{\text{Bkg}}\right) \\sim \text{NB}\\left(\alpha, \frac{\beta}{\beta +1}
-    \right) $$
-
-    !!! info
-        Here, $\alpha$ and $\beta$ are set to $\alpha = \text{Counts}_{\text{Bkg}} + 1$ and $\beta = 1$. Doing so,
-        the prior distribution is such that $\\mathbb{E}[\\lambda_{\text{Bkg}}] = \text{Counts}_{\text{Bkg}} +1$ and
-        $\text{Var}[\\lambda_{\text{Bkg}}] = \text{Counts}_{\text{Bkg}}+1$. The +1 is to avoid numerical issues when the
-        counts are 0, and add a small scatter even if the measured background is effectively null.
-
-    ??? abstract "References"
-
-        - https://en.wikipedia.org/wiki/Conjugate_prior
-        - https://www.acsu.buffalo.edu/~adamcunn/probability/gamma.html
-        - https://bayesiancomputationbook.com/markdown/chp_01.html?highlight=conjugate#conjugate-priors
-        - https://vioshyvo.github.io/Bayesian_inference/conjugate-distributions.html
-
-    """
-
-    def numpyro_model(self, energy, observed_counts, name: str = "bkg", observed=True):
-        # Gamma in numpyro is parameterized by concentration and rate (alpha/beta)
-        # alpha = observed_counts + 1
-        # beta = 1
-
-        with numpyro.plate(f"{name}_plate", len(observed_counts)):
-            countrate = numpyro.sample(f"{name}", dist.Gamma(2 * observed_counts + 1, 2), obs=None)
-
-        return countrate
-'''
-
-"""
-class SpectralBackgroundModel(BackgroundModel):
-    # I should pass the current spectral model as an argument to the background model
-    # In the numpyro model function
-    def __init__(self, model, prior):
-        self.model = model
-        self.prior = prior
-
-    def numpyro_model(self, energy, observed_counts, name: str = "bkg", observed=True):
-        #TODO : keep the sparsification from top model
-        transformed_model = hk.without_apply_rng(hk.transform(lambda par: CountForwardModel(model, obs, sparse=False)(par)))
-"""
-
-
 class GaussianProcessBackground(BackgroundModel):
     """
-
+    Define a Gaussian Process to model the background. The GP is built using the
     [`tinygp`](https://tinygp.readthedocs.io/en/stable/guide.html) library.
     """
@@ -146,16 +99,7 @@ class GaussianProcessBackground(BackgroundModel):
         self.n_nodes = n_nodes
         self.kernel = kernel
 
-    def numpyro_model(self, obs,
-        """
-        Build the model for the background.
-
-        Parameters:
-            energy: The energy bins lower and upper values (e_low, e_high).
-            observed_counts: The observed counts in each energy bin.
-            name: The name of the background model for parameters disambiguation.
-            observed: Whether the model is observed or not. Useful for `numpyro.infer.Predictive` calls.
-        """
+    def numpyro_model(self, obs, name: str = "", observed=True):
         energy, observed_counts = obs.out_energies, obs.folded_background.data
 
         if (observed_counts is not None) and (self.n_nodes >= len(observed_counts)):
@@ -163,28 +107,92 @@ class GaussianProcessBackground(BackgroundModel):
                 "More nodes than channels in the observation associated with GaussianProcessBackground."
             )
 
+        else:
+            observed_counts = jnp.asarray(observed_counts)
+
         # The parameters of the GP model
-        mean = numpyro.sample(
-
-
+        mean = numpyro.sample(
+            f"_bkg_{name}_mean", dist.Normal(jnp.log(jnp.mean(observed_counts)), 2.0)
+        )
+        sigma = numpyro.sample(f"_bkg_{name}_sigma", dist.HalfNormal(3.0))
+        rho = numpyro.sample(f"_bkg_{name}_rho", dist.HalfNormal(10.0))
 
         # Set up the kernel and GP objects
         kernel = sigma**2 * self.kernel(rho)
         nodes = jnp.linspace(0, 1, self.n_nodes)
         gp = GaussianProcess(kernel, nodes, diag=1e-5 * jnp.ones_like(nodes), mean=mean)
 
-        log_rate = numpyro.sample(f"
+        log_rate = numpyro.sample(f"_bkg_{name}_log_rate_nodes", gp.numpyro_dist())
+
         interp_count_rate = jnp.exp(
             jnp.interp(energy, nodes * (self.e_max - self.e_min) + self.e_min, log_rate)
         )
         count_rate = trapezoid(interp_count_rate, energy, axis=0)
 
         # Finally, our observation model is Poisson
-        with numpyro.plate(
-        # TODO : change to Poisson Likelihood when there is no background model
-        # TODO : Otherwise clip the background model to 1e-6 to avoid numerical issues
+        with numpyro.plate("bkg_plate_" + name, len(observed_counts)):
             numpyro.sample(
-                f"{name}", dist.Poisson(count_rate), obs=observed_counts if observed else None
+                f"bkg_{name}", dist.Poisson(count_rate), obs=observed_counts if observed else None
             )
 
         return count_rate
+
+
+class SpectralBackgroundModel(BackgroundModel):
+    def __init__(self, spectral_model: "SpectralModel", prior_distributions, sparse=False):
+        self.spectral_model = spectral_model
+        self.prior = PriorDictModel.from_dict(prior_distributions).nested_dict
+        self.sparse = sparse
+
+    def numpyro_model(self, observation, name: str = "", observed=True):
+        params = build_prior(self.prior, prefix=f"_bkg_{name}")
+        bkg_model = jax.jit(
+            lambda par: forward_model(self.spectral_model, par, observation, sparse=self.sparse)
+        )
+        bkg_countrate = bkg_model(params)
+
+        with numpyro.plate("bkg_plate_" + name, len(observation.folded_background)):
+            numpyro.sample(
+                "bkg_" + name,
+                Poisson(bkg_countrate),
+                obs=observation.folded_background.data if observed else None,
+            )
+
+        return bkg_countrate
+
+
+'''
+class ConjugateBackground(BackgroundModel):
+    r"""
+    This class fit an expected rate $\\lambda$ in each bin of the background spectrum. Assuming a Gamma prior
+    distribution, we can analytically derive the posterior as a Negative binomial distribution.
+
+    $$ p(\\lambda_{\text{Bkg}}) \\sim \\Gamma \\left( \alpha, \beta \right) \\implies
+    p\\left(\\lambda_{\text{Bkg}} | \text{Counts}_{\text{Bkg}}\right) \\sim \text{NB}\\left(\alpha, \frac{\beta}{\beta +1}
+    \right) $$
+
+    !!! info
+        Here, $\alpha$ and $\beta$ are set to $\alpha = \text{Counts}_{\text{Bkg}} + 1$ and $\beta = 1$. Doing so,
+        the prior distribution is such that $\\mathbb{E}[\\lambda_{\text{Bkg}}] = \text{Counts}_{\text{Bkg}} +1$ and
+        $\text{Var}[\\lambda_{\text{Bkg}}] = \text{Counts}_{\text{Bkg}}+1$. The +1 is to avoid numerical issues when the
+        counts are 0, and add a small scatter even if the measured background is effectively null.
+
+    ??? abstract "References"
+
+        - https://en.wikipedia.org/wiki/Conjugate_prior
+        - https://www.acsu.buffalo.edu/~adamcunn/probability/gamma.html
+        - https://bayesiancomputationbook.com/markdown/chp_01.html?highlight=conjugate#conjugate-priors
+        - https://vioshyvo.github.io/Bayesian_inference/conjugate-distributions.html
+
+    """
+
+    def numpyro_model(self, energy, observed_counts, name: str = "bkg", observed=True):
+        # Gamma in numpyro is parameterized by concentration and rate (alpha/beta)
+        # alpha = observed_counts + 1
+        # beta = 1
+
+        with numpyro.plate(f"{name}_plate", len(observed_counts)):
+            countrate = numpyro.sample(f"{name}", dist.Gamma(2 * observed_counts + 1, 2), obs=None)
+
+        return countrate
+'''
All other files listed above are unchanged between 0.1.3 and 0.1.4.