jaxspec 0.1.3__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxspec/_fit/__init__.py +0 -0
- jaxspec/_fit/_build_model.py +63 -0
- jaxspec/analysis/_plot.py +166 -7
- jaxspec/analysis/results.py +238 -336
- jaxspec/data/instrument.py +47 -12
- jaxspec/data/obsconf.py +12 -2
- jaxspec/data/observation.py +68 -11
- jaxspec/data/ogip.py +32 -13
- jaxspec/data/util.py +5 -75
- jaxspec/fit.py +101 -140
- jaxspec/model/_graph_util.py +151 -0
- jaxspec/model/abc.py +275 -414
- jaxspec/model/additive.py +276 -289
- jaxspec/model/background.py +94 -87
- jaxspec/model/multiplicative.py +101 -85
- jaxspec/scripts/debug.py +1 -1
- jaxspec/util/__init__.py +0 -45
- jaxspec/util/misc.py +25 -0
- jaxspec/util/typing.py +0 -63
- {jaxspec-0.1.3.dist-info → jaxspec-0.2.0.dist-info}/METADATA +36 -16
- jaxspec-0.2.0.dist-info/RECORD +34 -0
- {jaxspec-0.1.3.dist-info → jaxspec-0.2.0.dist-info}/WHEEL +1 -1
- jaxspec/data/grouping.py +0 -23
- jaxspec-0.1.3.dist-info/RECORD +0 -31
- {jaxspec-0.1.3.dist-info → jaxspec-0.2.0.dist-info}/LICENSE.md +0 -0
- {jaxspec-0.1.3.dist-info → jaxspec-0.2.0.dist-info}/entry_points.txt +0 -0
jaxspec/model/background.py
CHANGED
|
@@ -1,22 +1,27 @@
|
|
|
1
1
|
from abc import ABC, abstractmethod
|
|
2
2
|
|
|
3
|
+
import jax
|
|
3
4
|
import jax.numpy as jnp
|
|
4
5
|
import numpyro
|
|
5
6
|
import numpyro.distributions as dist
|
|
6
7
|
|
|
7
8
|
from jax.scipy.integrate import trapezoid
|
|
9
|
+
from numpyro.distributions import Poisson
|
|
8
10
|
from tinygp import GaussianProcess, kernels
|
|
9
11
|
|
|
12
|
+
from .._fit._build_model import build_prior, forward_model
|
|
13
|
+
from .abc import SpectralModel
|
|
14
|
+
|
|
10
15
|
|
|
11
16
|
class BackgroundModel(ABC):
|
|
12
17
|
"""
|
|
13
|
-
|
|
18
|
+
Handles the background modelling in our spectra. This is handled in a separate class for now
|
|
14
19
|
since backgrounds can be phenomenological models fitted directly on the folded spectrum. This is not the case for
|
|
15
20
|
the source model, which is fitted on the unfolded spectrum. This might be changed later.
|
|
16
21
|
"""
|
|
17
22
|
|
|
18
23
|
@abstractmethod
|
|
19
|
-
def numpyro_model(self,
|
|
24
|
+
def numpyro_model(self, observation, name: str = "", observed=True):
|
|
20
25
|
"""
|
|
21
26
|
Build the model for the background.
|
|
22
27
|
"""
|
|
@@ -25,7 +30,7 @@ class BackgroundModel(ABC):
|
|
|
25
30
|
|
|
26
31
|
class SubtractedBackground(BackgroundModel):
|
|
27
32
|
"""
|
|
28
|
-
|
|
33
|
+
Define a model where the observed background is simply subtracted from the observed spectrum.
|
|
29
34
|
|
|
30
35
|
!!! danger
|
|
31
36
|
|
|
@@ -35,93 +40,40 @@ class SubtractedBackground(BackgroundModel):
|
|
|
35
40
|
|
|
36
41
|
"""
|
|
37
42
|
|
|
38
|
-
def numpyro_model(self,
|
|
39
|
-
_, observed_counts =
|
|
40
|
-
numpyro.deterministic(f"{name}", observed_counts)
|
|
43
|
+
def numpyro_model(self, observation, name: str = "", observed=True):
    """
    Expose the measured folded background as a deterministic site and return it.

    Parameters:
        observation: Object providing `out_energies` and `folded_background.data`.
        name: Suffix used to build the site name `bkg_{name}`.
        observed: Not used by this model; present so the signature matches the other background models.
    """
    # The energies are read but not used, exactly as before.
    _ = observation.out_energies
    background_counts = observation.folded_background.data

    # No sampling here: the counts are recorded as-is under the `bkg_` prefix.
    numpyro.deterministic(f"bkg_{name}", background_counts)

    return background_counts
|
|
43
48
|
|
|
44
49
|
|
|
45
50
|
class BackgroundWithError(BackgroundModel):
    """
    Background model that accounts for the statistical spread of the measured background. One count rate is
    sampled per background bin and the measured background counts are given a Poisson likelihood around it.
    """

    def numpyro_model(self, obs, name: str = "", observed=True):
        """
        Sample one background count rate per bin and condition it on the measured background.

        Parameters:
            obs: Object providing `out_energies` and `folded_background.data`.
            name: Label inserted in the site names (`_bkg_{name}_countrate`, `bkg_{name}_plate`, `bkg_{name}`).
            observed: When false, the Poisson site is left unobserved (`obs=None`).

        Returns:
            The sampled count rate.
        """
        # The prior cannot go through a generic prior builder: its size follows the number of
        # bins of the current observation, so it is instantiated in place.
        _ = obs.out_energies  # read but unused, as in the previous implementation
        bkg_counts = obs.folded_background.data

        # numpyro's Gamma is parameterized by concentration and rate (alpha / beta).
        concentration = bkg_counts + 1
        rate = 1
        countrate = numpyro.sample(
            f"_bkg_{name}_countrate", dist.Gamma(concentration, rate=rate)
        )

        if observed:
            target = bkg_counts
        else:
            target = None

        with numpyro.plate(f"bkg_{name}_plate", len(bkg_counts)):
            numpyro.sample(f"bkg_{name}", dist.Poisson(countrate), obs=target)

        return countrate
|
|
68
72
|
|
|
69
73
|
|
|
70
|
-
'''
|
|
71
|
-
# TODO: Implement this class and sample it with Gibbs Sampling
|
|
72
|
-
|
|
73
|
-
class ConjugateBackground(BackgroundModel):
|
|
74
|
-
r"""
|
|
75
|
-
This class fit an expected rate $\\lambda$ in each bin of the background spectrum. Assuming a Gamma prior
|
|
76
|
-
distribution, we can analytically derive the posterior as a Negative binomial distribution.
|
|
77
|
-
|
|
78
|
-
$$ p(\\lambda_{\text{Bkg}}) \\sim \\Gamma \\left( \alpha, \beta \right) \\implies
|
|
79
|
-
p\\left(\\lambda_{\text{Bkg}} | \text{Counts}_{\text{Bkg}}\right) \\sim \text{NB}\\left(\alpha, \frac{\beta}{\beta +1}
|
|
80
|
-
\right) $$
|
|
81
|
-
|
|
82
|
-
!!! info
|
|
83
|
-
Here, $\alpha$ and $\beta$ are set to $\alpha = \text{Counts}_{\text{Bkg}} + 1$ and $\beta = 1$. Doing so,
|
|
84
|
-
the prior distribution is such that $\\mathbb{E}[\\lambda_{\text{Bkg}}] = \text{Counts}_{\text{Bkg}} +1$ and
|
|
85
|
-
$\text{Var}[\\lambda_{\text{Bkg}}] = \text{Counts}_{\text{Bkg}}+1$. The +1 is to avoid numerical issues when the
|
|
86
|
-
counts are 0, and add a small scatter even if the measured background is effectively null.
|
|
87
|
-
|
|
88
|
-
??? abstract "References"
|
|
89
|
-
|
|
90
|
-
- https://en.wikipedia.org/wiki/Conjugate_prior
|
|
91
|
-
- https://www.acsu.buffalo.edu/~adamcunn/probability/gamma.html
|
|
92
|
-
- https://bayesiancomputationbook.com/markdown/chp_01.html?highlight=conjugate#conjugate-priors
|
|
93
|
-
- https://vioshyvo.github.io/Bayesian_inference/conjugate-distributions.html
|
|
94
|
-
|
|
95
|
-
"""
|
|
96
|
-
|
|
97
|
-
def numpyro_model(self, energy, observed_counts, name: str = "bkg", observed=True):
|
|
98
|
-
# Gamma in numpyro is parameterized by concentration and rate (alpha/beta)
|
|
99
|
-
# alpha = observed_counts + 1
|
|
100
|
-
# beta = 1
|
|
101
|
-
|
|
102
|
-
with numpyro.plate(f"{name}_plate", len(observed_counts)):
|
|
103
|
-
countrate = numpyro.sample(f"{name}", dist.Gamma(2 * observed_counts + 1, 2), obs=None)
|
|
104
|
-
|
|
105
|
-
return countrate
|
|
106
|
-
'''
|
|
107
|
-
|
|
108
|
-
"""
|
|
109
|
-
class SpectralBackgroundModel(BackgroundModel):
|
|
110
|
-
# I should pass the current spectral model as an argument to the background model
|
|
111
|
-
# In the numpyro model function
|
|
112
|
-
def __init__(self, model, prior):
|
|
113
|
-
self.model = model
|
|
114
|
-
self.prior = prior
|
|
115
|
-
|
|
116
|
-
def numpyro_model(self, energy, observed_counts, name: str = "bkg", observed=True):
|
|
117
|
-
#TODO : keep the sparsification from top model
|
|
118
|
-
transformed_model = hk.without_apply_rng(hk.transform(lambda par: CountForwardModel(model, obs, sparse=False)(par)))
|
|
119
|
-
"""
|
|
120
|
-
|
|
121
|
-
|
|
122
74
|
class GaussianProcessBackground(BackgroundModel):
|
|
123
75
|
"""
|
|
124
|
-
|
|
76
|
+
Define a Gaussian Process to model the background. The GP is built using the
|
|
125
77
|
[`tinygp`](https://tinygp.readthedocs.io/en/stable/guide.html) library.
|
|
126
78
|
"""
|
|
127
79
|
|
|
@@ -146,16 +98,7 @@ class GaussianProcessBackground(BackgroundModel):
|
|
|
146
98
|
self.n_nodes = n_nodes
|
|
147
99
|
self.kernel = kernel
|
|
148
100
|
|
|
149
|
-
def numpyro_model(self, obs,
|
|
150
|
-
"""
|
|
151
|
-
Build the model for the background.
|
|
152
|
-
|
|
153
|
-
Parameters:
|
|
154
|
-
energy: The energy bins lower and upper values (e_low, e_high).
|
|
155
|
-
observed_counts: The observed counts in each energy bin.
|
|
156
|
-
name: The name of the background model for parameters disambiguation.
|
|
157
|
-
observed: Whether the model is observed or not. Useful for `numpyro.infer.Predictive` calls.
|
|
158
|
-
"""
|
|
101
|
+
def numpyro_model(self, obs, name: str = "", observed=True):
|
|
159
102
|
energy, observed_counts = obs.out_energies, obs.folded_background.data
|
|
160
103
|
|
|
161
104
|
if (observed_counts is not None) and (self.n_nodes >= len(observed_counts)):
|
|
@@ -163,28 +106,92 @@ class GaussianProcessBackground(BackgroundModel):
|
|
|
163
106
|
"More nodes than channels in the observation associated with GaussianProcessBackground."
|
|
164
107
|
)
|
|
165
108
|
|
|
109
|
+
else:
|
|
110
|
+
observed_counts = jnp.asarray(observed_counts)
|
|
111
|
+
|
|
166
112
|
# The parameters of the GP model
|
|
167
|
-
mean = numpyro.sample(
|
|
168
|
-
|
|
169
|
-
|
|
113
|
+
mean = numpyro.sample(
|
|
114
|
+
f"_bkg_{name}_mean", dist.Normal(jnp.log(jnp.mean(observed_counts)), 2.0)
|
|
115
|
+
)
|
|
116
|
+
sigma = numpyro.sample(f"_bkg_{name}_sigma", dist.HalfNormal(3.0))
|
|
117
|
+
rho = numpyro.sample(f"_bkg_{name}_rho", dist.HalfNormal(10.0))
|
|
170
118
|
|
|
171
119
|
# Set up the kernel and GP objects
|
|
172
120
|
kernel = sigma**2 * self.kernel(rho)
|
|
173
121
|
nodes = jnp.linspace(0, 1, self.n_nodes)
|
|
174
122
|
gp = GaussianProcess(kernel, nodes, diag=1e-5 * jnp.ones_like(nodes), mean=mean)
|
|
175
123
|
|
|
176
|
-
log_rate = numpyro.sample(f"
|
|
124
|
+
log_rate = numpyro.sample(f"_bkg_{name}_log_rate_nodes", gp.numpyro_dist())
|
|
125
|
+
|
|
177
126
|
interp_count_rate = jnp.exp(
|
|
178
127
|
jnp.interp(energy, nodes * (self.e_max - self.e_min) + self.e_min, log_rate)
|
|
179
128
|
)
|
|
180
129
|
count_rate = trapezoid(interp_count_rate, energy, axis=0)
|
|
181
130
|
|
|
182
131
|
# Finally, our observation model is Poisson
|
|
183
|
-
with numpyro.plate(
|
|
184
|
-
# TODO : change to Poisson Likelihood when there is no background model
|
|
185
|
-
# TODO : Otherwise clip the background model to 1e-6 to avoid numerical issues
|
|
132
|
+
with numpyro.plate("bkg_plate_" + name, len(observed_counts)):
|
|
186
133
|
numpyro.sample(
|
|
187
|
-
f"{name}", dist.Poisson(count_rate), obs=observed_counts if observed else None
|
|
134
|
+
f"bkg_{name}", dist.Poisson(count_rate), obs=observed_counts if observed else None
|
|
188
135
|
)
|
|
189
136
|
|
|
190
137
|
return count_rate
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
class SpectralModelBackground(BackgroundModel):
    """
    Background described by a spectral model evaluated through `forward_model` for the current observation,
    with a Poisson likelihood on the measured folded background.
    """

    def __init__(self, spectral_model: "SpectralModel", prior_distributions, sparse=False):
        # Model evaluated by `forward_model` to predict the background counts.
        self.spectral_model = spectral_model
        # Prior specification handed to `build_prior` when the numpyro model is built.
        self.prior = prior_distributions
        # Forwarded unchanged as the `sparse` keyword of `forward_model`.
        self.sparse = sparse

    def numpyro_model(self, observation, name: str = "", observed=True):
        """
        Sample the background model parameters and condition on the measured background.

        Parameters:
            observation: Observation passed to `forward_model`; its `folded_background` provides the data.
            name: Label used in the site names (`_bkg_{name}_` prefix, `bkg_plate_` and `bkg_` sites).
            observed: When false, the Poisson site is left unobserved (`obs=None`).

        Returns:
            The background count rate computed by `forward_model`.
        """
        sampled_parameters = build_prior(self.prior, prefix=f"_bkg_{name}_")

        def _predict(par):
            return forward_model(self.spectral_model, par, observation, sparse=self.sparse)

        # NOTE(review): a new function object is jitted on every call, so presumably jit's cache is
        # never reused across calls of this method — confirm whether this is intended.
        bkg_countrate = jax.jit(_predict)(sampled_parameters)

        if observed:
            measured = observation.folded_background.data
        else:
            measured = None

        with numpyro.plate("bkg_plate_" + name, len(observation.folded_background)):
            numpyro.sample("bkg_" + name, Poisson(bkg_countrate), obs=measured)

        return bkg_countrate
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
'''
|
|
164
|
+
class ConjugateBackground(BackgroundModel):
|
|
165
|
+
r"""
|
|
166
|
+
This class fits an expected rate $\\lambda$ in each bin of the background spectrum. Assuming a Gamma prior
|
|
167
|
+
distribution, we can analytically derive the posterior as a Negative binomial distribution.
|
|
168
|
+
|
|
169
|
+
$$ p(\\lambda_{\text{Bkg}}) \\sim \\Gamma \\left( \alpha, \beta \right) \\implies
|
|
170
|
+
p\\left(\\lambda_{\text{Bkg}} | \text{Counts}_{\text{Bkg}}\right) \\sim \text{NB}\\left(\alpha, \frac{\beta}{\beta +1}
|
|
171
|
+
\right) $$
|
|
172
|
+
|
|
173
|
+
!!! info
|
|
174
|
+
Here, $\alpha$ and $\beta$ are set to $\alpha = \text{Counts}_{\text{Bkg}} + 1$ and $\beta = 1$. Doing so,
|
|
175
|
+
the prior distribution is such that $\\mathbb{E}[\\lambda_{\text{Bkg}}] = \text{Counts}_{\text{Bkg}} +1$ and
|
|
176
|
+
$\text{Var}[\\lambda_{\text{Bkg}}] = \text{Counts}_{\text{Bkg}}+1$. The +1 is to avoid numerical issues when the
|
|
177
|
+
counts are 0, and add a small scatter even if the measured background is effectively null.
|
|
178
|
+
|
|
179
|
+
??? abstract "References"
|
|
180
|
+
|
|
181
|
+
- https://en.wikipedia.org/wiki/Conjugate_prior
|
|
182
|
+
- https://www.acsu.buffalo.edu/~adamcunn/probability/gamma.html
|
|
183
|
+
- https://bayesiancomputationbook.com/markdown/chp_01.html?highlight=conjugate#conjugate-priors
|
|
184
|
+
- https://vioshyvo.github.io/Bayesian_inference/conjugate-distributions.html
|
|
185
|
+
|
|
186
|
+
"""
|
|
187
|
+
|
|
188
|
+
def numpyro_model(self, energy, observed_counts, name: str = "bkg", observed=True):
|
|
189
|
+
# Gamma in numpyro is parameterized by concentration and rate (alpha/beta)
|
|
190
|
+
# alpha = observed_counts + 1
|
|
191
|
+
# beta = 1
|
|
192
|
+
|
|
193
|
+
with numpyro.plate(f"{name}_plate", len(observed_counts)):
|
|
194
|
+
countrate = numpyro.sample(f"{name}", dist.Gamma(2 * observed_counts + 1, 2), obs=None)
|
|
195
|
+
|
|
196
|
+
return countrate
|
|
197
|
+
'''
|
jaxspec/model/multiplicative.py
CHANGED
|
@@ -1,16 +1,31 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
import
|
|
3
|
+
import flax.nnx as nnx
|
|
4
4
|
import jax.numpy as jnp
|
|
5
5
|
import numpy as np
|
|
6
6
|
|
|
7
7
|
from astropy.table import Table
|
|
8
|
-
from haiku.initializers import Constant as HaikuConstant
|
|
9
8
|
|
|
10
9
|
from ..util.online_storage import table_manager
|
|
11
10
|
from .abc import MultiplicativeComponent
|
|
12
11
|
|
|
13
12
|
|
|
13
|
+
class MultiplicativeConstant(MultiplicativeComponent):
    r"""
    An energy-independent multiplicative factor.

    !!! abstract "Parameters"
        * $K$ (`norm`) $\left[\text{dimensionless}\right]$: The multiplicative constant.

    """

    def __init__(self):
        # Single free parameter, initialised to 1.0.
        self.norm = nnx.Param(1.0)

    def factor(self, energy):
        """Return the `norm` parameter; `energy` is not used."""
        return self.norm
|
|
27
|
+
|
|
28
|
+
|
|
14
29
|
class Expfac(MultiplicativeComponent):
|
|
15
30
|
r"""
|
|
16
31
|
An exponential modification of a spectrum.
|
|
@@ -20,19 +35,20 @@ class Expfac(MultiplicativeComponent):
|
|
|
20
35
|
\text{if $E>E_c$}\\1 & \text{if $E<E_c$}\end{cases}
|
|
21
36
|
$$
|
|
22
37
|
|
|
23
|
-
|
|
24
|
-
* $A$
|
|
25
|
-
* $f$
|
|
26
|
-
* $E_c$
|
|
38
|
+
!!! abstract "Parameters"
|
|
39
|
+
* $A$ (`A`) $\left[\text{dimensionless}\right]$ : Amplitude of the modification
|
|
40
|
+
* $f$ (`f`) $\left[\text{keV}^{-1}\right]$ : Exponential factor
|
|
41
|
+
* $E_c$ (`E_c`) $\left[\text{keV}\right]$: Start energy of modification
|
|
27
42
|
|
|
28
43
|
"""
|
|
29
44
|
|
|
30
|
-
def
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
45
|
+
def __init__(self):
|
|
46
|
+
self.A = nnx.Param(1.0)
|
|
47
|
+
self.f = nnx.Param(1.0)
|
|
48
|
+
self.E_c = nnx.Param(1.0)
|
|
34
49
|
|
|
35
|
-
|
|
50
|
+
def factor(self, energy):
    """
    Evaluate the exponential modification at `energy`.

    Returns `1 + A * exp(-f * energy)` where `energy >= E_c` and 1 elsewhere.
    """
    at_or_above_start = energy >= self.E_c
    modified = 1.0 + self.A * jnp.exp(-self.f * energy)
    return jnp.where(at_or_above_start, modified, 1.0)
|
|
36
52
|
|
|
37
53
|
|
|
38
54
|
class Tbabs(MultiplicativeComponent):
|
|
@@ -45,49 +61,47 @@ class Tbabs(MultiplicativeComponent):
|
|
|
45
61
|
\mathcal{M}(E) = \exp^{-N_{\text{H}}\sigma(E)}
|
|
46
62
|
$$
|
|
47
63
|
|
|
48
|
-
|
|
49
|
-
* $N_{\text{H}}$ : Equivalent hydrogen column density
|
|
50
|
-
|
|
64
|
+
!!! abstract "Parameters"
|
|
65
|
+
* $N_{\text{H}}$ (`nh`) $\left[\frac{\text{atoms}~10^{22}}{\text{cm}^2}\right]$ : Equivalent hydrogen column density
|
|
66
|
+
|
|
51
67
|
|
|
52
68
|
!!! note
|
|
53
69
|
Abundances and cross-sections $\sigma$ can be found in Wilms et al. (2000).
|
|
54
70
|
|
|
55
71
|
"""
|
|
56
72
|
|
|
57
|
-
def __init__(self
|
|
58
|
-
super().__init__(*args, **kwargs)
|
|
73
|
+
def __init__(self):
|
|
59
74
|
table = Table.read(table_manager.fetch("xsect_tbabs_wilm.fits"))
|
|
60
|
-
self.energy =
|
|
61
|
-
self.sigma =
|
|
75
|
+
self.energy = nnx.Variable(np.asarray(table["ENERGY"], dtype=np.float64))
|
|
76
|
+
self.sigma = nnx.Variable(np.asarray(table["SIGMA"], dtype=np.float64))
|
|
77
|
+
self.nh = nnx.Param(1.0)
|
|
62
78
|
|
|
63
|
-
def
|
|
64
|
-
nh = hk.get_parameter("N_H", [], float, init=HaikuConstant(1))
|
|
79
|
+
def factor(self, energy):
    """
    Evaluate the absorption factor `exp(-nh * sigma(energy))`.

    The cross-section is linearly interpolated on the tabulated grid; it is set to 1e9 below the
    first tabulated energy and to 0 above the last one.
    """
    cross_section = jnp.interp(energy, self.energy, self.sigma, left=1e9, right=0.0)
    optical_depth = self.nh * cross_section

    return jnp.exp(-optical_depth)
|
|
68
83
|
|
|
69
84
|
|
|
70
85
|
class Phabs(MultiplicativeComponent):
    r"""
    A photoelectric absorption model.

    !!! abstract "Parameters"
        * $N_{\text{H}}$ (`nh`) $\left[\frac{\text{atoms}~10^{22}}{\text{cm}^2}\right]$ : Equivalent hydrogen column density

    """

    def __init__(self):
        # Tabulated cross-sections are fetched through the table manager and stored as float64.
        xsect = Table.read(table_manager.fetch("xsect_phabs_aspl.fits"))
        energy_grid = np.asarray(xsect["ENERGY"], dtype=np.float64)
        sigma_grid = np.asarray(xsect["SIGMA"], dtype=np.float64)

        self.energy = nnx.Variable(energy_grid)
        self.sigma = nnx.Variable(sigma_grid)
        self.nh = nnx.Param(1.0)

    def factor(self, energy):
        """
        Evaluate the absorption factor `exp(-nh * sigma(energy))`.

        The cross-section is linearly interpolated on the tabulated grid; it is set to 1e9 below the
        first tabulated energy and to 0 above the last one.
        """
        cross_section = jnp.interp(energy, self.energy, self.sigma, left=1e9, right=0.0)
        optical_depth = self.nh * cross_section

        return jnp.exp(-optical_depth)
|
|
91
105
|
|
|
92
106
|
|
|
93
107
|
class Wabs(MultiplicativeComponent):
|
|
@@ -95,22 +109,19 @@ class Wabs(MultiplicativeComponent):
|
|
|
95
109
|
A photo-electric absorption using Wisconsin (Morrison & McCammon 1983) cross-sections.
|
|
96
110
|
|
|
97
111
|
??? abstract "Parameters"
|
|
98
|
-
* $N_{\text{H}}$ : Equivalent hydrogen column density
|
|
99
|
-
$\left[\frac{\text{atoms}~10^{22}}{\text{cm}^2}\right]$
|
|
100
|
-
|
|
112
|
+
* $N_{\text{H}}$ (`nh`) $\left[\frac{\text{atoms}~10^{22}}{\text{cm}^2}\right]$ : Equivalent hydrogen column density
|
|
101
113
|
"""
|
|
102
114
|
|
|
103
|
-
def __init__(self
|
|
104
|
-
super().__init__(*args, **kwargs)
|
|
115
|
+
def __init__(self):
|
|
105
116
|
table = Table.read(table_manager.fetch("xsect_wabs_angr.fits"))
|
|
106
|
-
self.energy =
|
|
107
|
-
self.sigma =
|
|
117
|
+
self.energy = nnx.Variable(np.asarray(table["ENERGY"], dtype=np.float64))
|
|
118
|
+
self.sigma = nnx.Variable(np.asarray(table["SIGMA"], dtype=np.float64))
|
|
119
|
+
self.nh = nnx.Param(1.0)
|
|
108
120
|
|
|
109
|
-
def
|
|
110
|
-
nh = hk.get_parameter("N_H", [], float, init=HaikuConstant(1))
|
|
121
|
+
def factor(self, energy):
    """
    Evaluate the absorption factor `exp(-nh * sigma(energy))`.

    The cross-section is linearly interpolated on the tabulated grid; it is set to 1e9 below the
    first tabulated energy and to 0 above the last one.
    """
    cross_section = jnp.interp(energy, self.energy, self.sigma, left=1e9, right=0.0)
    optical_depth = self.nh * cross_section

    return jnp.exp(-optical_depth)
|
|
114
125
|
|
|
115
126
|
|
|
116
127
|
class Gabs(MultiplicativeComponent):
|
|
@@ -122,23 +133,26 @@ class Gabs(MultiplicativeComponent):
|
|
|
122
133
|
\left( -\frac{\left(E-E_0\right)^2}{2 \sigma^2} \right) \right)
|
|
123
134
|
$$
|
|
124
135
|
|
|
125
|
-
|
|
126
|
-
* $\tau$
|
|
127
|
-
* $\sigma$
|
|
128
|
-
* $E_0$
|
|
136
|
+
!!! abstract "Parameters"
|
|
137
|
+
* $\tau$ (`tau`) $\left[\text{dimensionless}\right]$ : Absorption strength
|
|
138
|
+
* $\sigma$ (`sigma`) $\left[\text{keV}\right]$ : Absorption width
|
|
139
|
+
* $E_0$ (`E0`) $\left[\text{keV}\right]$ : Absorption center
|
|
129
140
|
|
|
130
141
|
!!! note
|
|
131
142
|
The optical depth at line center is $\tau/(\sqrt{2 \pi} \sigma)$.
|
|
132
143
|
|
|
133
144
|
"""
|
|
134
145
|
|
|
135
|
-
def
|
|
136
|
-
tau =
|
|
137
|
-
sigma =
|
|
138
|
-
|
|
146
|
+
def __init__(self):
|
|
147
|
+
self.tau = nnx.Param(1.0)
|
|
148
|
+
self.sigma = nnx.Param(1.0)
|
|
149
|
+
self.E0 = nnx.Param(1.0)
|
|
139
150
|
|
|
151
|
+
def factor(self, energy):
    """
    Evaluate the Gaussian absorption line at `energy`.

    The optical depth is a Gaussian profile centred on `E0` with width `sigma`, scaled by
    `tau / (sqrt(2 * pi) * sigma)`; the returned factor is `exp(-optical_depth)`.
    """
    depth_at_center = self.tau / (jnp.sqrt(2 * jnp.pi) * self.sigma)
    profile = jnp.exp(-0.5 * ((energy - self.E0) / self.sigma) ** 2)

    return jnp.exp(-depth_at_center * profile)
|
|
143
157
|
|
|
144
158
|
|
|
@@ -151,16 +165,17 @@ class Highecut(MultiplicativeComponent):
|
|
|
151
165
|
\left( \frac{E_c - E}{E_f} \right)& \text{if $E > E_c$}\\ 1 & \text{if $E < E_c$}\end{cases}
|
|
152
166
|
$$
|
|
153
167
|
|
|
154
|
-
|
|
155
|
-
* $E_c$
|
|
156
|
-
* $E_f$
|
|
168
|
+
!!! abstract "Parameters"
|
|
169
|
+
* $E_c$ (`Ec`) $\left[\text{keV}\right]$ : Cutoff energy
|
|
170
|
+
* $E_f$ (`Ef`) $\left[\text{keV}\right]$ : Folding energy
|
|
157
171
|
"""
|
|
158
172
|
|
|
159
|
-
def
|
|
160
|
-
|
|
161
|
-
|
|
173
|
+
def __init__(self):
|
|
174
|
+
self.Ec = nnx.Param(1.0)
|
|
175
|
+
self.Ef = nnx.Param(1.0)
|
|
162
176
|
|
|
163
|
-
|
|
177
|
+
def factor(self, energy):
    """
    Evaluate the high-energy cutoff at `energy`.

    Returns 1 where `energy <= Ec` and `exp((Ec - energy) / Ef)` elsewhere.
    """
    below_cutoff = energy <= self.Ec
    rolloff = jnp.exp((self.Ec - energy) / self.Ef)
    return jnp.where(below_cutoff, 1.0, rolloff)
|
|
164
179
|
|
|
165
180
|
|
|
166
181
|
class Zedge(MultiplicativeComponent):
|
|
@@ -172,18 +187,21 @@ class Zedge(MultiplicativeComponent):
|
|
|
172
187
|
& \text{if $E > E_c$}\\ 1 & \text{if $E < E_c$}\end{cases}
|
|
173
188
|
$$
|
|
174
189
|
|
|
175
|
-
|
|
176
|
-
* $E_c$ : Threshold energy
|
|
177
|
-
* $
|
|
178
|
-
* $z$
|
|
190
|
+
!!! abstract "Parameters"
|
|
191
|
+
* $E_c$ (`Ec`) $\left[\text{keV}\right]$ : Threshold energy
|
|
192
|
+
* $D$ (`D`) $\left[\text{dimensionless}\right]$ : Absorption depth at the threshold
|
|
193
|
+
* $z$ (`z`) $\left[\text{dimensionless}\right]$ : Redshift
|
|
179
194
|
"""
|
|
180
195
|
|
|
181
|
-
def
|
|
182
|
-
|
|
183
|
-
D =
|
|
184
|
-
z =
|
|
196
|
+
def __init__(self):
|
|
197
|
+
self.Ec = nnx.Param(1.0)
|
|
198
|
+
self.D = nnx.Param(1.0)
|
|
199
|
+
self.z = nnx.Param(0.0)
|
|
185
200
|
|
|
186
|
-
|
|
201
|
+
def factor(self, energy):
    """
    Evaluate the redshifted absorption edge at `energy`.

    Returns 1 where `energy <= Ec` and `exp(-D * (energy * (1 + z) / Ec) ** 3)` elsewhere.
    """
    # NOTE(review): the exponent is +3 here; if this is meant to mirror XSPEC's `zedge`
    # (which presumably uses -3), confirm against the class docstring formula.
    scaled_energy = energy * (1 + self.z) / self.Ec
    absorbed = jnp.exp(-self.D * scaled_energy**3)
    return jnp.where(energy <= self.Ec, 1.0, absorbed)
|
|
187
205
|
|
|
188
206
|
|
|
189
207
|
class Tbpcf(MultiplicativeComponent):
|
|
@@ -194,28 +212,25 @@ class Tbpcf(MultiplicativeComponent):
|
|
|
194
212
|
\mathcal{M}(E) = f \exp^{-N_{\text{H}}\sigma(E)} + (1-f)
|
|
195
213
|
$$
|
|
196
214
|
|
|
197
|
-
|
|
198
|
-
* $N_{\text{H}}$ : Equivalent hydrogen column density
|
|
199
|
-
|
|
200
|
-
* $f$ : Partial covering fraction, ranges between 0 and 1 [dimensionless]
|
|
215
|
+
!!! abstract "Parameters"
|
|
216
|
+
* $N_{\text{H}}$ (`nh`) $\left[\frac{\text{atoms}~10^{22}}{\text{cm}^2}\right]$ : Equivalent hydrogen column density
|
|
217
|
+
* $f$ (`f`) $\left[\text{dimensionless}\right]$ : Partial covering fraction, ranges between 0 and 1
|
|
201
218
|
|
|
202
219
|
!!! note
|
|
203
220
|
Abundances and cross-sections $\sigma$ can be found in Wilms et al. (2000).
|
|
204
221
|
|
|
205
222
|
"""
|
|
206
223
|
|
|
207
|
-
def __init__(self
|
|
208
|
-
super().__init__(*args, **kwargs)
|
|
224
|
+
def __init__(self):
|
|
209
225
|
table = Table.read(table_manager.fetch("xsect_tbabs_wilm.fits"))
|
|
210
|
-
self.energy =
|
|
211
|
-
self.sigma =
|
|
226
|
+
self.energy = nnx.Variable(np.asarray(table["ENERGY"], dtype=np.float64))
|
|
227
|
+
self.sigma = nnx.Variable(np.asarray(table["SIGMA"], dtype=np.float64))
|
|
228
|
+
self.nh = nnx.Param(1.0)
|
|
229
|
+
self.f = nnx.Param(0.2)
|
|
212
230
|
|
|
213
231
|
def factor(self, energy):
    """
    Evaluate the partially covering absorption `f * exp(-nh * sigma(energy)) + (1 - f)`.

    The cross-section is linearly interpolated on the tabulated grid; it is set to 1e9 below the
    first tabulated energy and to 0 above the last one.

    This method carries the name used by the other multiplicative components of this module.
    """
    sigma = jnp.interp(energy, self.energy, self.sigma, left=1e9, right=0.0)
    return self.f * jnp.exp(-self.nh * sigma) + (1 - self.f)

def continuum(self, energy):
    """Backward-compatible alias of `factor`, kept for callers still using the former method name."""
    return self.factor(energy)
|
|
219
234
|
|
|
220
235
|
|
|
221
236
|
class FDcut(MultiplicativeComponent):
|
|
@@ -227,12 +242,13 @@ class FDcut(MultiplicativeComponent):
|
|
|
227
242
|
$$
|
|
228
243
|
|
|
229
244
|
??? abstract "Parameters"
|
|
230
|
-
* $E_c$
|
|
231
|
-
* $E_f$
|
|
245
|
+
* $E_c$ (`Ec`) $\left[\text{keV}\right]$ : Cutoff energy
|
|
246
|
+
* $E_f$ (`Ef`) $\left[\text{keV}\right]$ : Folding energy
|
|
232
247
|
"""
|
|
233
248
|
|
|
234
|
-
def
|
|
235
|
-
|
|
236
|
-
|
|
249
|
+
def __init__(self):
|
|
250
|
+
self.Ec = nnx.Param(1.0)
|
|
251
|
+
self.Ef = nnx.Param(3.0)
|
|
237
252
|
|
|
238
|
-
|
|
253
|
+
def factor(self, energy):
    """
    Evaluate the Fermi-Dirac style cutoff `1 / (1 + exp((energy - Ec) / Ef))`.

    This method carries the name used by the other multiplicative components of this module.
    """
    return (1 + jnp.exp((energy - self.Ec) / self.Ef)) ** -1

def continuum(self, energy):
    """Backward-compatible alias of `factor`, kept for callers still using the former method name."""
    return self.factor(energy)
|
jaxspec/scripts/debug.py
CHANGED
jaxspec/util/__init__.py
CHANGED
|
@@ -1,45 +0,0 @@
|
|
|
1
|
-
from collections.abc import Callable
|
|
2
|
-
from contextlib import contextmanager
|
|
3
|
-
from time import perf_counter
|
|
4
|
-
|
|
5
|
-
import haiku as hk
|
|
6
|
-
|
|
7
|
-
from jax.random import PRNGKey, split
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
@contextmanager
|
|
11
|
-
def catchtime(desc="Task", print_time=True) -> Callable[[], float]:
|
|
12
|
-
"""
|
|
13
|
-
Context manager to measure time taken by a task.
|
|
14
|
-
|
|
15
|
-
Parameters
|
|
16
|
-
----------
|
|
17
|
-
desc (str): Description of the task.
|
|
18
|
-
print_time (bool): Whether to print the time taken by the task.
|
|
19
|
-
|
|
20
|
-
Returns
|
|
21
|
-
-------
|
|
22
|
-
Callable[[], float]: Function to get the time taken by the task.
|
|
23
|
-
"""
|
|
24
|
-
|
|
25
|
-
t1 = t2 = perf_counter()
|
|
26
|
-
yield lambda: t2 - t1
|
|
27
|
-
t2 = perf_counter()
|
|
28
|
-
if print_time:
|
|
29
|
-
print(f"{desc}: {t2 - t1:.3f} seconds")
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
def sample_prior(dict_of_prior, key=PRNGKey(0), flat_parameters=False):
|
|
33
|
-
"""
|
|
34
|
-
Sample the prior distribution from a dict of distributions
|
|
35
|
-
"""
|
|
36
|
-
|
|
37
|
-
parameters = dict(hk.data_structures.to_haiku_dict(dict_of_prior))
|
|
38
|
-
parameters_flat = {}
|
|
39
|
-
|
|
40
|
-
for m, n, distribution in hk.data_structures.traverse(dict_of_prior):
|
|
41
|
-
key, subkey = split(key)
|
|
42
|
-
parameters[m][n] = distribution.sample(subkey)
|
|
43
|
-
parameters_flat[m + "_" + n] = distribution.sample(subkey)
|
|
44
|
-
|
|
45
|
-
return parameters if not flat_parameters else parameters_flat
|
jaxspec/util/misc.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
from collections.abc import Callable
|
|
2
|
+
from contextlib import contextmanager
|
|
3
|
+
from time import perf_counter
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@contextmanager
def catchtime(desc="Task", print_time=True) -> Callable[[], float]:
    """
    Context manager to measure the time taken by the body of a ``with`` block.

    The end time is recorded even when the body raises, so the yielded callable and the
    printed message stay meaningful for failing tasks; the exception still propagates.

    Parameters
    ----------
    desc (str): Description of the task, used as prefix of the printed message.
    print_time (bool): Whether to print the time taken by the task when the block exits.

    Yields
    ------
    Callable[[], float]: Function returning the elapsed time in seconds. It returns 0.0 while
        the block is still running and the final duration once the block has exited.
    """

    t1 = t2 = perf_counter()
    try:
        # The lambda closes over t2, which is rebound in the finally clause below.
        yield lambda: t2 - t1
    finally:
        t2 = perf_counter()
        if print_time:
            print(f"{desc}: {t2 - t1:.3f} seconds")
|