jaxspec 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxspec/analysis/results.py +297 -121
- jaxspec/data/__init__.py +4 -4
- jaxspec/data/obsconf.py +53 -8
- jaxspec/data/util.py +114 -84
- jaxspec/fit.py +335 -96
- jaxspec/model/__init__.py +0 -1
- jaxspec/model/_additive/apec.py +56 -117
- jaxspec/model/_additive/apec_loaders.py +42 -59
- jaxspec/model/additive.py +194 -55
- jaxspec/model/background.py +50 -16
- jaxspec/model/multiplicative.py +63 -41
- jaxspec/util/__init__.py +45 -0
- jaxspec/util/abundance.py +5 -3
- jaxspec/util/online_storage.py +28 -0
- jaxspec/util/typing.py +43 -0
- {jaxspec-0.0.6.dist-info → jaxspec-0.0.8.dist-info}/METADATA +14 -10
- {jaxspec-0.0.6.dist-info → jaxspec-0.0.8.dist-info}/RECORD +19 -25
- {jaxspec-0.0.6.dist-info → jaxspec-0.0.8.dist-info}/WHEEL +1 -1
- jaxspec/data/example_data/MOS1.pha +0 -46
- jaxspec/data/example_data/MOS2.pha +0 -42
- jaxspec/data/example_data/PN.pha +1 -293
- jaxspec/data/example_data/fakeit.pha +1 -335
- jaxspec/tables/abundances.dat +0 -31
- jaxspec/tables/xsect_phabs_aspl.fits +0 -0
- jaxspec/tables/xsect_tbabs_wilm.fits +0 -0
- jaxspec/tables/xsect_wabs_angr.fits +0 -0
- {jaxspec-0.0.6.dist-info → jaxspec-0.0.8.dist-info}/LICENSE.md +0 -0
jaxspec/model/additive.py
CHANGED
|
@@ -1,16 +1,22 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
|
|
3
|
+
import astropy.constants
|
|
4
|
+
import astropy.units as u
|
|
4
5
|
import haiku as hk
|
|
6
|
+
import interpax
|
|
5
7
|
import jax
|
|
6
8
|
import jax.numpy as jnp
|
|
7
9
|
import jax.scipy as jsp
|
|
8
|
-
import
|
|
9
|
-
|
|
10
|
+
import numpy as np
|
|
11
|
+
|
|
12
|
+
from astropy.table import Table
|
|
10
13
|
from haiku.initializers import Constant as HaikuConstant
|
|
14
|
+
|
|
11
15
|
from ..util.integrate import integrate_interval
|
|
16
|
+
from ..util.online_storage import table_manager
|
|
17
|
+
|
|
18
|
+
# from ._additive.apec import APEC
|
|
12
19
|
from .abc import AdditiveComponent
|
|
13
|
-
# from ._additive.apec import APEC # noqa: F401
|
|
14
20
|
|
|
15
21
|
|
|
16
22
|
class Powerlaw(AdditiveComponent):
|
|
@@ -26,8 +32,8 @@ class Powerlaw(AdditiveComponent):
|
|
|
26
32
|
"""
|
|
27
33
|
|
|
28
34
|
def continuum(self, energy):
|
|
29
|
-
alpha = hk.get_parameter("alpha", [], init=HaikuConstant(1.3))
|
|
30
|
-
norm = hk.get_parameter("norm", [], init=HaikuConstant(1e-4))
|
|
35
|
+
alpha = hk.get_parameter("alpha", [], float, init=HaikuConstant(1.3))
|
|
36
|
+
norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1e-4))
|
|
31
37
|
|
|
32
38
|
return norm * energy ** (-alpha)
|
|
33
39
|
|
|
@@ -43,12 +49,12 @@ class AdditiveConstant(AdditiveComponent):
|
|
|
43
49
|
"""
|
|
44
50
|
|
|
45
51
|
def continuum(self, energy):
|
|
46
|
-
norm = hk.get_parameter("norm", [], init=HaikuConstant(1))
|
|
52
|
+
norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
|
|
47
53
|
|
|
48
54
|
return norm * jnp.ones_like(energy)
|
|
49
55
|
|
|
50
56
|
def primitive(self, energy):
|
|
51
|
-
norm = hk.get_parameter("norm", [], init=HaikuConstant(1))
|
|
57
|
+
norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
|
|
52
58
|
|
|
53
59
|
return norm * energy
|
|
54
60
|
|
|
@@ -66,9 +72,9 @@ class Lorentz(AdditiveComponent):
|
|
|
66
72
|
"""
|
|
67
73
|
|
|
68
74
|
def continuum(self, energy):
|
|
69
|
-
line_energy = hk.get_parameter("E_l", [], init=HaikuConstant(1))
|
|
70
|
-
sigma = hk.get_parameter("sigma", [], init=HaikuConstant(1))
|
|
71
|
-
norm = hk.get_parameter("norm", [], init=HaikuConstant(1))
|
|
75
|
+
line_energy = hk.get_parameter("E_l", [], float, init=HaikuConstant(1))
|
|
76
|
+
sigma = hk.get_parameter("sigma", [], float, init=HaikuConstant(1))
|
|
77
|
+
norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
|
|
72
78
|
|
|
73
79
|
return norm * sigma / (2 * jnp.pi) / ((energy - line_energy) ** 2 + (sigma / 2) ** 2)
|
|
74
80
|
|
|
@@ -91,9 +97,9 @@ class Logparabola(AdditiveComponent):
|
|
|
91
97
|
|
|
92
98
|
# TODO : conform with xspec definition
|
|
93
99
|
def continuum(self, energy):
|
|
94
|
-
a = hk.get_parameter("a", [], init=HaikuConstant(11 / 3))
|
|
95
|
-
b = hk.get_parameter("b", [], init=HaikuConstant(0.2))
|
|
96
|
-
norm = hk.get_parameter("norm", [], init=HaikuConstant(1))
|
|
100
|
+
a = hk.get_parameter("a", [], float, init=HaikuConstant(11 / 3))
|
|
101
|
+
b = hk.get_parameter("b", [], float, init=HaikuConstant(0.2))
|
|
102
|
+
norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
|
|
97
103
|
|
|
98
104
|
return norm * energy ** (-(a + b * jnp.log(energy)))
|
|
99
105
|
|
|
@@ -111,8 +117,8 @@ class Blackbody(AdditiveComponent):
|
|
|
111
117
|
"""
|
|
112
118
|
|
|
113
119
|
def continuum(self, energy):
|
|
114
|
-
kT = hk.get_parameter("kT", [], init=HaikuConstant(11 / 3))
|
|
115
|
-
norm = hk.get_parameter("norm", [], init=HaikuConstant(1))
|
|
120
|
+
kT = hk.get_parameter("kT", [], float, init=HaikuConstant(11 / 3))
|
|
121
|
+
norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
|
|
116
122
|
|
|
117
123
|
return norm * 8.0525 * energy**2 / ((kT**4) * (jnp.exp(energy / kT) - 1))
|
|
118
124
|
|
|
@@ -130,8 +136,8 @@ class Blackbodyrad(AdditiveComponent):
|
|
|
130
136
|
"""
|
|
131
137
|
|
|
132
138
|
def continuum(self, energy):
|
|
133
|
-
kT = hk.get_parameter("kT", [], init=HaikuConstant(11 / 3))
|
|
134
|
-
norm = hk.get_parameter("norm", [], init=HaikuConstant(1))
|
|
139
|
+
kT = hk.get_parameter("kT", [], float, init=HaikuConstant(11 / 3))
|
|
140
|
+
norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
|
|
135
141
|
|
|
136
142
|
return norm * 1.0344e-3 * energy**2 / (jnp.exp(energy / kT) - 1)
|
|
137
143
|
|
|
@@ -150,9 +156,9 @@ class Gauss(AdditiveComponent):
|
|
|
150
156
|
"""
|
|
151
157
|
|
|
152
158
|
def continuum(self, energy) -> (jax.Array, jax.Array):
|
|
153
|
-
line_energy = hk.get_parameter("E_l", [], init=HaikuConstant(1))
|
|
154
|
-
sigma = hk.get_parameter("sigma", [], init=HaikuConstant(1))
|
|
155
|
-
norm = hk.get_parameter("norm", [], init=HaikuConstant(1))
|
|
159
|
+
line_energy = hk.get_parameter("E_l", [], float, init=HaikuConstant(1))
|
|
160
|
+
sigma = hk.get_parameter("sigma", [], float, init=HaikuConstant(1))
|
|
161
|
+
norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
|
|
156
162
|
|
|
157
163
|
return norm * jsp.stats.norm.pdf(energy, loc=line_energy, scale=sigma)
|
|
158
164
|
|
|
@@ -198,9 +204,9 @@ class APEC(AdditiveComponent):
|
|
|
198
204
|
)
|
|
199
205
|
|
|
200
206
|
def mono_fine_structure(self, e_low, e_high) -> (jax.Array, jax.Array):
|
|
201
|
-
norm = hk.get_parameter("norm", [], init=HaikuConstant(1))
|
|
202
|
-
kT = hk.get_parameter("kT", [], init=HaikuConstant(1))
|
|
203
|
-
Z = hk.get_parameter("Z", [], init=HaikuConstant(1))
|
|
207
|
+
norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
|
|
208
|
+
kT = hk.get_parameter("kT", [], float, init=HaikuConstant(1))
|
|
209
|
+
Z = hk.get_parameter("Z", [], float, init=HaikuConstant(1))
|
|
204
210
|
|
|
205
211
|
idx = jnp.searchsorted(self.kT_ref, kT, side="left") - 1
|
|
206
212
|
|
|
@@ -229,9 +235,9 @@ class APEC(AdditiveComponent):
|
|
|
229
235
|
|
|
230
236
|
@partial(jnp.vectorize, excluded=(0,))
|
|
231
237
|
def continuum(self, energy):
|
|
232
|
-
norm = hk.get_parameter("norm", [], init=HaikuConstant(1))
|
|
233
|
-
kT = hk.get_parameter("kT", [], init=HaikuConstant(1))
|
|
234
|
-
Z = hk.get_parameter("Z", [], init=HaikuConstant(1))
|
|
238
|
+
norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
|
|
239
|
+
kT = hk.get_parameter("kT", [], float, init=HaikuConstant(1))
|
|
240
|
+
Z = hk.get_parameter("Z", [], float, init=HaikuConstant(1))
|
|
235
241
|
|
|
236
242
|
idx = jnp.searchsorted(self.kT_ref, kT, side="left") - 1 # index of left value
|
|
237
243
|
|
|
@@ -269,9 +275,9 @@ class Cutoffpl(AdditiveComponent):
|
|
|
269
275
|
"""
|
|
270
276
|
|
|
271
277
|
def continuum(self, energy):
|
|
272
|
-
alpha = hk.get_parameter("alpha", [], init=HaikuConstant(1.3))
|
|
273
|
-
beta = hk.get_parameter("beta", [], init=HaikuConstant(15))
|
|
274
|
-
norm = hk.get_parameter("norm", [], init=HaikuConstant(1e-4))
|
|
278
|
+
alpha = hk.get_parameter("alpha", [], float, init=HaikuConstant(1.3))
|
|
279
|
+
beta = hk.get_parameter("beta", [], float, init=HaikuConstant(15))
|
|
280
|
+
norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1e-4))
|
|
275
281
|
|
|
276
282
|
return norm * energy ** (-alpha) * jnp.exp(-energy / beta)
|
|
277
283
|
|
|
@@ -283,22 +289,22 @@ class Diskpbb(AdditiveComponent):
|
|
|
283
289
|
where $$p$$ is a free parameter. The standard disk model, diskbb, is recovered if $$p=0.75$$.
|
|
284
290
|
If radial advection is important then $$p<0.75$$.
|
|
285
291
|
|
|
286
|
-
|
|
287
|
-
|
|
292
|
+
$$\\mathcal{M}\\left( E \right) = \frac{2\\pi(\\cos i)r^{2}_{\text{in}}}{pd^2} \\int_{T_{\text{in}}}^{T_{\text{out}}}
|
|
293
|
+
\\left( \frac{T}{T_{\text{in}}} \right)^{-2/p-1} \text{bbody}(E,T) \frac{dT}{T_{\text{in}}}$$
|
|
288
294
|
|
|
289
295
|
??? abstract "Parameters"
|
|
290
|
-
* $\text{norm}$ :
|
|
291
|
-
where $r_{\text{in}}$ is "an apparent" inner disk radius
|
|
296
|
+
* $\text{norm}$ : $\\cos i(r_{\text{in}}/d)^{2}$,
|
|
297
|
+
where $r_{\text{in}}$ is "an apparent" inner disk radius $\\left[\text{km}\right]$,
|
|
292
298
|
$d$ the distance to the source in units of $10 \text{kpc}$,
|
|
293
299
|
$i$ the angle of the disk ($i=0$ is face-on)
|
|
294
|
-
* $p$ : Exponent of the radial dependence of the disk temperature
|
|
295
|
-
* $T_{\text{in}}$ : Temperature at inner disk radius
|
|
300
|
+
* $p$ : Exponent of the radial dependence of the disk temperature $\\left[\text{dimensionless}\right]$
|
|
301
|
+
* $T_{\text{in}}$ : Temperature at inner disk radius $\\left[ \\mathrm{keV}\right]$
|
|
296
302
|
"""
|
|
297
303
|
|
|
298
304
|
def continuum(self, energy):
|
|
299
|
-
norm = hk.get_parameter("norm", [], init=HaikuConstant(1))
|
|
300
|
-
p = hk.get_parameter("p", [], init=HaikuConstant(0.75))
|
|
301
|
-
tin = hk.get_parameter("Tin", [], init=HaikuConstant(1))
|
|
305
|
+
norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
|
|
306
|
+
p = hk.get_parameter("p", [], float, init=HaikuConstant(0.75))
|
|
307
|
+
tin = hk.get_parameter("Tin", [], float, init=HaikuConstant(1))
|
|
302
308
|
|
|
303
309
|
# Tout is set to 0 as it is evaluated at R=infinity
|
|
304
310
|
def integrand(kT, energy):
|
|
@@ -324,15 +330,21 @@ class Diskbb(AdditiveComponent):
|
|
|
324
330
|
def continuum(self, energy):
|
|
325
331
|
p = 0.75
|
|
326
332
|
tout = 0.0
|
|
327
|
-
tin = hk.get_parameter("Tin", [], init=HaikuConstant(1))
|
|
328
|
-
norm = hk.get_parameter("norm", [], init=HaikuConstant(1))
|
|
333
|
+
tin = hk.get_parameter("Tin", [], float, init=HaikuConstant(1))
|
|
334
|
+
norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
|
|
329
335
|
|
|
330
336
|
# Tout is set to 0 as it is evaluated at R=infinity
|
|
331
337
|
def integrand(kT, e, tin, p):
|
|
332
338
|
return e**2 * (kT / tin) ** (-2 / p - 1) / (jnp.exp(e / kT) - 1)
|
|
333
339
|
|
|
334
340
|
integral = integrate_interval(integrand)
|
|
335
|
-
return
|
|
341
|
+
return (
|
|
342
|
+
norm
|
|
343
|
+
* 2.78e-3
|
|
344
|
+
* (0.75 / p)
|
|
345
|
+
/ tin
|
|
346
|
+
* jnp.vectorize(lambda e: integral(tout, tin, e, tin, p))(energy)
|
|
347
|
+
)
|
|
336
348
|
|
|
337
349
|
|
|
338
350
|
class Agauss(AdditiveComponent):
|
|
@@ -352,9 +364,9 @@ class Agauss(AdditiveComponent):
|
|
|
352
364
|
|
|
353
365
|
def continuum(self, energy) -> (jax.Array, jax.Array):
|
|
354
366
|
hc = (astropy.constants.h * astropy.constants.c).to(u.angstrom * u.keV).value
|
|
355
|
-
line_wavelength = hk.get_parameter("Lambda_l", [], init=HaikuConstant(hc))
|
|
356
|
-
sigma = hk.get_parameter("sigma", [], init=HaikuConstant(0.001))
|
|
357
|
-
norm = hk.get_parameter("norm", [], init=HaikuConstant(1))
|
|
367
|
+
line_wavelength = hk.get_parameter("Lambda_l", [], float, init=HaikuConstant(hc))
|
|
368
|
+
sigma = hk.get_parameter("sigma", [], float, init=HaikuConstant(0.001))
|
|
369
|
+
norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
|
|
358
370
|
|
|
359
371
|
return norm * jsp.stats.norm.pdf(hc / energy, loc=line_wavelength, scale=sigma)
|
|
360
372
|
|
|
@@ -376,12 +388,16 @@ class Zagauss(AdditiveComponent):
|
|
|
376
388
|
|
|
377
389
|
def continuum(self, energy) -> (jax.Array, jax.Array):
|
|
378
390
|
hc = (astropy.constants.h * astropy.constants.c).to(u.angstrom * u.keV).value
|
|
379
|
-
line_wavelength = hk.get_parameter("Lambda_l", [], init=HaikuConstant(hc))
|
|
380
|
-
sigma = hk.get_parameter("sigma", [], init=HaikuConstant(0.001))
|
|
381
|
-
norm = hk.get_parameter("norm", [], init=HaikuConstant(1))
|
|
382
|
-
redshift = hk.get_parameter("redshift", [], init=HaikuConstant(0))
|
|
391
|
+
line_wavelength = hk.get_parameter("Lambda_l", [], float, init=HaikuConstant(hc))
|
|
392
|
+
sigma = hk.get_parameter("sigma", [], float, init=HaikuConstant(0.001))
|
|
393
|
+
norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
|
|
394
|
+
redshift = hk.get_parameter("redshift", [], float, init=HaikuConstant(0))
|
|
383
395
|
|
|
384
|
-
return
|
|
396
|
+
return (
|
|
397
|
+
norm
|
|
398
|
+
* (1 + redshift)
|
|
399
|
+
* jsp.stats.norm.pdf((hc / energy) / (1 + redshift), loc=line_wavelength, scale=sigma)
|
|
400
|
+
)
|
|
385
401
|
|
|
386
402
|
|
|
387
403
|
class Zgauss(AdditiveComponent):
|
|
@@ -399,9 +415,132 @@ class Zgauss(AdditiveComponent):
|
|
|
399
415
|
"""
|
|
400
416
|
|
|
401
417
|
def continuum(self, energy) -> (jax.Array, jax.Array):
|
|
402
|
-
line_energy = hk.get_parameter("E_l", [], init=HaikuConstant(1))
|
|
403
|
-
sigma = hk.get_parameter("sigma", [], init=HaikuConstant(1))
|
|
404
|
-
norm = hk.get_parameter("norm", [], init=HaikuConstant(1))
|
|
405
|
-
redshift = hk.get_parameter("redshift", [], init=HaikuConstant(0))
|
|
418
|
+
line_energy = hk.get_parameter("E_l", [], float, init=HaikuConstant(1))
|
|
419
|
+
sigma = hk.get_parameter("sigma", [], float, init=HaikuConstant(1))
|
|
420
|
+
norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
|
|
421
|
+
redshift = hk.get_parameter("redshift", [], float, init=HaikuConstant(0))
|
|
422
|
+
|
|
423
|
+
return (norm / (1 + redshift)) * jsp.stats.norm.pdf(
|
|
424
|
+
energy * (1 + redshift), loc=line_energy, scale=sigma
|
|
425
|
+
)
|
|
426
|
+
|
|
427
|
+
|
|
428
|
+
class NSatmos(AdditiveComponent):
|
|
429
|
+
r"""
|
|
430
|
+
A neutron star atmosphere model based on the `NSATMOS` model from `XSPEC`. See [this page](https://heasarc.gsfc.nasa.gov/docs/xanadu/xspec/manual/node205.html)
|
|
431
|
+
|
|
432
|
+
!!! warning
|
|
433
|
+
The boundary case of $R_{\text{NS}} < 1.125 R_{\text{S}}$ is handled with a null flux instead of a constant value as in `XSPEC`.
|
|
434
|
+
|
|
435
|
+
??? abstract "Parameters"
|
|
436
|
+
* $T_{eff}$ : Effective temperature at the surface in K (No redshift applied)
|
|
437
|
+
* $M_{ns}$ : Mass of the NS in solar masses
|
|
438
|
+
* $R_∞$ : Radius at infinity (modulated by gravitational effects) in km
|
|
439
|
+
* $D$ : Distance to the neutron star in kpc
|
|
440
|
+
* norm : fraction of the neutron star surface emitting
|
|
441
|
+
"""
|
|
442
|
+
|
|
443
|
+
def __init__(self, *args, **kwargs):
|
|
444
|
+
super().__init__(*args, **kwargs)
|
|
445
|
+
entry_table = Table.read(table_manager.fetch("nsatmosdata.fits"), 1)
|
|
446
|
+
|
|
447
|
+
# Extract the table values. All this code could be summarize in two lines if we reformat the nsatmosdata.fits table
|
|
448
|
+
self.tab_temperature = np.asarray(entry_table["TEMP"][0], dtype=float) # Logarithmic value
|
|
449
|
+
self.tab_gravity = np.asarray(entry_table["GRAVITY"][0], dtype=float) # Logarithmic value
|
|
450
|
+
self.tab_mucrit = np.asarray(entry_table["MUCRIT"][0], dtype=float)
|
|
451
|
+
self.tab_energy = np.asarray(entry_table["ENERGY"][0], dtype=float)
|
|
452
|
+
self.tab_flux_flat = Table.read(table_manager.fetch("nsatmosdata.fits"), 2)["FLUX"]
|
|
453
|
+
|
|
454
|
+
tab_flux = np.empty(
|
|
455
|
+
(
|
|
456
|
+
self.tab_temperature.size,
|
|
457
|
+
self.tab_gravity.size,
|
|
458
|
+
self.tab_mucrit.size,
|
|
459
|
+
self.tab_energy.size,
|
|
460
|
+
)
|
|
461
|
+
)
|
|
462
|
+
|
|
463
|
+
for i in range(len(self.tab_temperature)):
|
|
464
|
+
for j in range(len(self.tab_gravity)):
|
|
465
|
+
for k in range(len(self.tab_mucrit)):
|
|
466
|
+
tab_flux[i, j, k] = np.array(
|
|
467
|
+
self.tab_flux_flat[
|
|
468
|
+
i * len(self.tab_gravity) * len(self.tab_mucrit)
|
|
469
|
+
+ j * len(self.tab_mucrit)
|
|
470
|
+
+ k
|
|
471
|
+
]
|
|
472
|
+
)
|
|
473
|
+
|
|
474
|
+
self.tab_flux = np.asarray(tab_flux, dtype=float)
|
|
475
|
+
|
|
476
|
+
def interp_flux_func(self, temperature_log, gravity_log, mu):
|
|
477
|
+
# Interpolate the tables to get the flux on the tabulated energy grid
|
|
478
|
+
|
|
479
|
+
return interpax.interp3d(
|
|
480
|
+
10.0**temperature_log,
|
|
481
|
+
10.0**gravity_log,
|
|
482
|
+
mu,
|
|
483
|
+
10.0**self.tab_temperature,
|
|
484
|
+
10.0**self.tab_gravity,
|
|
485
|
+
self.tab_mucrit,
|
|
486
|
+
self.tab_flux,
|
|
487
|
+
method="linear",
|
|
488
|
+
)
|
|
489
|
+
|
|
490
|
+
def continuum(self, energy):
|
|
491
|
+
temp_log = hk.get_parameter(
|
|
492
|
+
"Tinf", [], float, init=HaikuConstant(6.0)
|
|
493
|
+
) # log10 of temperature in Kelvin
|
|
494
|
+
|
|
495
|
+
# 'Tinf': temp_log, # 5 to 6.5
|
|
496
|
+
# 'M': mass, # 0.5 to 3
|
|
497
|
+
# 'Rns': radius, # 5 to 30
|
|
498
|
+
|
|
499
|
+
mass = hk.get_parameter("M", [], float, init=HaikuConstant(1.4))
|
|
500
|
+
radius = hk.get_parameter("Rns", [], float, init=HaikuConstant(10.0))
|
|
501
|
+
distance = hk.get_parameter("dns", [], float, init=HaikuConstant(10.0))
|
|
502
|
+
norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1.0))
|
|
503
|
+
|
|
504
|
+
# Derive parameters usable to retrive value in the flux table
|
|
505
|
+
Rcgs = 1e5 * radius # Radius in cgs
|
|
506
|
+
r_schwarzschild = 2.95e5 * mass # Schwarzschild radius in cgs
|
|
507
|
+
r_normalized = Rcgs / r_schwarzschild # Ratio of the radius to the Schwarzschild radius
|
|
508
|
+
r_over_Dsql = 2 * jnp.log10(
|
|
509
|
+
Rcgs / (3.09e21 * distance)
|
|
510
|
+
) # Log( (R/D)**2 ), 3.09e21 constant transforms radius in cgs to kpc
|
|
511
|
+
zred1 = 1 / jnp.sqrt(1 - (1 / r_normalized)) # Gravitational redshift
|
|
512
|
+
gravity = (6.67e-8 * 1.99e33 * mass / Rcgs**2) * zred1 # Gravity field g in cgs
|
|
513
|
+
gravity_log = jnp.log10(
|
|
514
|
+
gravity
|
|
515
|
+
) # Log gravity because this is the format given in the table
|
|
516
|
+
|
|
517
|
+
# Not sure about mu yet, but it is linked to causality
|
|
518
|
+
cmu = jnp.where(
|
|
519
|
+
r_normalized < 1.5, jnp.sqrt(1.0 - 6.75 / r_normalized**2 + 6.75 / r_normalized**3), 0.0
|
|
520
|
+
)
|
|
521
|
+
|
|
522
|
+
# Interpolate the flux table to get the flux at the surface
|
|
523
|
+
|
|
524
|
+
flux = jax.jit(self.interp_flux_func)(temp_log, gravity_log, cmu)
|
|
525
|
+
|
|
526
|
+
# Rescale the photon energies and fluxes back to the correct local temperature
|
|
527
|
+
Tfactor = 10.0 ** (temp_log - 6.0)
|
|
528
|
+
fluxshift = 3.0 * (temp_log - 6.0)
|
|
529
|
+
E = self.tab_energy * Tfactor
|
|
530
|
+
flux += fluxshift
|
|
531
|
+
|
|
532
|
+
# Rescale applying redshift
|
|
533
|
+
fluxshift = -jnp.log10(zred1)
|
|
534
|
+
E = E / zred1
|
|
535
|
+
flux += fluxshift
|
|
536
|
+
|
|
537
|
+
# Convert to counts/keV (which corresponds to dividing by1.602e-9*EkeV)
|
|
538
|
+
# Multiply by the area of the star, and calculate the count rate at the observer
|
|
539
|
+
flux += r_over_Dsql
|
|
540
|
+
counts = 10.0 ** (flux - jnp.log10(1.602e-9) - jnp.log10(E))
|
|
541
|
+
|
|
542
|
+
true_flux = norm * jnp.exp(
|
|
543
|
+
interpax.interp1d(jnp.log(energy), jnp.log(E), jnp.log(counts), method="linear")
|
|
544
|
+
)
|
|
406
545
|
|
|
407
|
-
return (
|
|
546
|
+
return jax.lax.select(r_normalized < 1.125, jnp.zeros_like(true_flux), true_flux)
|
jaxspec/model/background.py
CHANGED
|
@@ -1,9 +1,11 @@
|
|
|
1
1
|
from abc import ABC, abstractmethod
|
|
2
|
-
|
|
3
|
-
from jax.scipy.integrate import trapezoid
|
|
4
|
-
import numpyro.distributions as dist
|
|
2
|
+
|
|
5
3
|
import jax.numpy as jnp
|
|
6
4
|
import numpyro
|
|
5
|
+
import numpyro.distributions as dist
|
|
6
|
+
|
|
7
|
+
from jax.scipy.integrate import trapezoid
|
|
8
|
+
from tinygp import GaussianProcess, kernels
|
|
7
9
|
|
|
8
10
|
|
|
9
11
|
class BackgroundModel(ABC):
|
|
@@ -33,7 +35,8 @@ class SubtractedBackground(BackgroundModel):
|
|
|
33
35
|
|
|
34
36
|
"""
|
|
35
37
|
|
|
36
|
-
def numpyro_model(self,
|
|
38
|
+
def numpyro_model(self, obs, spectral_model, name: str = "bkg", observed=True):
|
|
39
|
+
_, observed_counts = obs.out_energies, obs.folded_background.data
|
|
37
40
|
numpyro.deterministic(f"{name}", observed_counts)
|
|
38
41
|
|
|
39
42
|
return jnp.zeros_like(observed_counts)
|
|
@@ -49,32 +52,37 @@ class BackgroundWithError(BackgroundModel):
|
|
|
49
52
|
but slower since it performs the fit using MCMC instead of analytical solution.
|
|
50
53
|
"""
|
|
51
54
|
|
|
52
|
-
def numpyro_model(self,
|
|
55
|
+
def numpyro_model(self, obs, spectral_model, name: str = "bkg", observed=True):
|
|
53
56
|
# Gamma in numpyro is parameterized by concentration and rate (alpha/beta)
|
|
57
|
+
_, observed_counts = obs.out_energies, obs.folded_background.data
|
|
54
58
|
alpha = observed_counts + 1
|
|
55
59
|
beta = 1
|
|
56
60
|
countrate = numpyro.sample(f"{name}_params", dist.Gamma(alpha, rate=beta))
|
|
57
61
|
|
|
58
62
|
with numpyro.plate(f"{name}_plate", len(observed_counts)):
|
|
59
|
-
numpyro.sample(
|
|
63
|
+
numpyro.sample(
|
|
64
|
+
f"{name}", dist.Poisson(countrate), obs=observed_counts if observed else None
|
|
65
|
+
)
|
|
60
66
|
|
|
61
67
|
return countrate
|
|
62
68
|
|
|
63
69
|
|
|
64
70
|
'''
|
|
71
|
+
# TODO: Implement this class and sample it with Gibbs Sampling
|
|
72
|
+
|
|
65
73
|
class ConjugateBackground(BackgroundModel):
|
|
66
74
|
r"""
|
|
67
|
-
This class fit an expected rate
|
|
75
|
+
This class fit an expected rate $\\lambda$ in each bin of the background spectrum. Assuming a Gamma prior
|
|
68
76
|
distribution, we can analytically derive the posterior as a Negative binomial distribution.
|
|
69
77
|
|
|
70
|
-
$$ p(
|
|
71
|
-
p
|
|
78
|
+
$$ p(\\lambda_{\text{Bkg}}) \\sim \\Gamma \\left( \alpha, \beta \right) \\implies
|
|
79
|
+
p\\left(\\lambda_{\text{Bkg}} | \text{Counts}_{\text{Bkg}}\right) \\sim \text{NB}\\left(\alpha, \frac{\beta}{\beta +1}
|
|
72
80
|
\right) $$
|
|
73
81
|
|
|
74
82
|
!!! info
|
|
75
83
|
Here, $\alpha$ and $\beta$ are set to $\alpha = \text{Counts}_{\text{Bkg}} + 1$ and $\beta = 1$. Doing so,
|
|
76
|
-
the prior distribution is such that
|
|
77
|
-
$\text{Var}[
|
|
84
|
+
the prior distribution is such that $\\mathbb{E}[\\lambda_{\text{Bkg}}] = \text{Counts}_{\text{Bkg}} +1$ and
|
|
85
|
+
$\text{Var}[\\lambda_{\text{Bkg}}] = \text{Counts}_{\text{Bkg}}+1$. The +1 is to avoid numerical issues when the
|
|
78
86
|
counts are 0, and add a small scatter even if the measured background is effectively null.
|
|
79
87
|
|
|
80
88
|
??? abstract "References"
|
|
@@ -97,6 +105,19 @@ class ConjugateBackground(BackgroundModel):
|
|
|
97
105
|
return countrate
|
|
98
106
|
'''
|
|
99
107
|
|
|
108
|
+
"""
|
|
109
|
+
class SpectralBackgroundModel(BackgroundModel):
|
|
110
|
+
# I should pass the current spectral model as an argument to the background model
|
|
111
|
+
# In the numpyro model function
|
|
112
|
+
def __init__(self, model, prior):
|
|
113
|
+
self.model = model
|
|
114
|
+
self.prior = prior
|
|
115
|
+
|
|
116
|
+
def numpyro_model(self, energy, observed_counts, name: str = "bkg", observed=True):
|
|
117
|
+
#TODO : keep the sparsification from top model
|
|
118
|
+
transformed_model = hk.without_apply_rng(hk.transform(lambda par: CountForwardModel(model, obs, sparse=False)(par)))
|
|
119
|
+
"""
|
|
120
|
+
|
|
100
121
|
|
|
101
122
|
class GaussianProcessBackground(BackgroundModel):
|
|
102
123
|
"""
|
|
@@ -104,7 +125,13 @@ class GaussianProcessBackground(BackgroundModel):
|
|
|
104
125
|
[`tinygp`](https://tinygp.readthedocs.io/en/stable/guide.html) library.
|
|
105
126
|
"""
|
|
106
127
|
|
|
107
|
-
def __init__(
|
|
128
|
+
def __init__(
|
|
129
|
+
self,
|
|
130
|
+
e_min: float,
|
|
131
|
+
e_max: float,
|
|
132
|
+
n_nodes: int = 30,
|
|
133
|
+
kernel: kernels.Kernel = kernels.Matern52,
|
|
134
|
+
):
|
|
108
135
|
"""
|
|
109
136
|
Build the Gaussian Process background model.
|
|
110
137
|
|
|
@@ -119,7 +146,7 @@ class GaussianProcessBackground(BackgroundModel):
|
|
|
119
146
|
self.n_nodes = n_nodes
|
|
120
147
|
self.kernel = kernel
|
|
121
148
|
|
|
122
|
-
def numpyro_model(self,
|
|
149
|
+
def numpyro_model(self, obs, spectral_model, name: str = "bkg", observed=True):
|
|
123
150
|
"""
|
|
124
151
|
Build the model for the background.
|
|
125
152
|
|
|
@@ -129,9 +156,12 @@ class GaussianProcessBackground(BackgroundModel):
|
|
|
129
156
|
name: The name of the background model for parameters disambiguation.
|
|
130
157
|
observed: Whether the model is observed or not. Useful for `numpyro.infer.Predictive` calls.
|
|
131
158
|
"""
|
|
159
|
+
energy, observed_counts = obs.out_energies, obs.folded_background.data
|
|
132
160
|
|
|
133
161
|
if (observed_counts is not None) and (self.n_nodes >= len(observed_counts)):
|
|
134
|
-
raise RuntimeError(
|
|
162
|
+
raise RuntimeError(
|
|
163
|
+
"More nodes than channels in the observation associated with GaussianProcessBackground."
|
|
164
|
+
)
|
|
135
165
|
|
|
136
166
|
# The parameters of the GP model
|
|
137
167
|
mean = numpyro.sample(f"{name}_mean", dist.Normal(jnp.log(jnp.mean(observed_counts)), 2.0))
|
|
@@ -144,13 +174,17 @@ class GaussianProcessBackground(BackgroundModel):
|
|
|
144
174
|
gp = GaussianProcess(kernel, nodes, diag=1e-5 * jnp.ones_like(nodes), mean=mean)
|
|
145
175
|
|
|
146
176
|
log_rate = numpyro.sample(f"_{name}_log_rate_nodes", gp.numpyro_dist())
|
|
147
|
-
interp_count_rate = jnp.exp(
|
|
177
|
+
interp_count_rate = jnp.exp(
|
|
178
|
+
jnp.interp(energy, nodes * (self.e_max - self.e_min) + self.e_min, log_rate)
|
|
179
|
+
)
|
|
148
180
|
count_rate = trapezoid(interp_count_rate, energy, axis=0)
|
|
149
181
|
|
|
150
182
|
# Finally, our observation model is Poisson
|
|
151
183
|
with numpyro.plate(f"{name}_plate", len(observed_counts)):
|
|
152
184
|
# TODO : change to Poisson Likelihood when there is no background model
|
|
153
185
|
# TODO : Otherwise clip the background model to 1e-6 to avoid numerical issues
|
|
154
|
-
numpyro.sample(
|
|
186
|
+
numpyro.sample(
|
|
187
|
+
f"{name}", dist.Poisson(count_rate), obs=observed_counts if observed else None
|
|
188
|
+
)
|
|
155
189
|
|
|
156
190
|
return count_rate
|