jaxspec 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxspec/analysis/results.py +250 -121
- jaxspec/data/__init__.py +4 -4
- jaxspec/data/obsconf.py +53 -8
- jaxspec/data/util.py +29 -20
- jaxspec/fit.py +329 -81
- jaxspec/model/__init__.py +0 -1
- jaxspec/model/_additive/apec.py +56 -117
- jaxspec/model/_additive/apec_loaders.py +42 -59
- jaxspec/model/additive.py +27 -13
- jaxspec/model/background.py +50 -16
- jaxspec/model/multiplicative.py +20 -25
- jaxspec/util/__init__.py +45 -0
- jaxspec/util/abundance.py +5 -3
- jaxspec/util/online_storage.py +15 -0
- jaxspec/util/typing.py +43 -0
- {jaxspec-0.0.5.dist-info → jaxspec-0.0.7.dist-info}/METADATA +12 -9
- {jaxspec-0.0.5.dist-info → jaxspec-0.0.7.dist-info}/RECORD +19 -22
- jaxspec/tables/abundances.dat +0 -31
- jaxspec/tables/new_apec.nc +0 -0
- jaxspec/tables/xsect_phabs_aspl.fits +0 -0
- jaxspec/tables/xsect_tbabs_wilm.fits +0 -0
- jaxspec/tables/xsect_wabs_angr.fits +0 -0
- {jaxspec-0.0.5.dist-info → jaxspec-0.0.7.dist-info}/LICENSE.md +0 -0
- {jaxspec-0.0.5.dist-info → jaxspec-0.0.7.dist-info}/WHEEL +0 -0
jaxspec/model/_additive/apec.py
CHANGED
|
@@ -1,16 +1,21 @@
|
|
|
1
|
-
import
|
|
2
|
-
|
|
3
|
-
|
|
1
|
+
import warnings
|
|
2
|
+
|
|
3
|
+
from typing import Literal
|
|
4
|
+
|
|
4
5
|
import astropy.units as u
|
|
6
|
+
import haiku as hk
|
|
7
|
+
import jax
|
|
8
|
+
import jax.numpy as jnp
|
|
9
|
+
|
|
10
|
+
from astropy.constants import c, m_p
|
|
11
|
+
from haiku.initializers import Constant as HaikuConstant
|
|
5
12
|
from jax import lax
|
|
6
|
-
from jax.lax import
|
|
13
|
+
from jax.lax import fori_loop, scan
|
|
7
14
|
from jax.scipy.stats import norm as gaussian
|
|
8
|
-
|
|
15
|
+
|
|
9
16
|
from ...util.abundance import abundance_table, element_data
|
|
10
|
-
from haiku.initializers import Constant as HaikuConstant
|
|
11
|
-
from astropy.constants import c, m_p
|
|
12
17
|
from ..abc import AdditiveComponent
|
|
13
|
-
from .apec_loaders import
|
|
18
|
+
from .apec_loaders import get_continuum, get_lines, get_pseudo, get_temperature
|
|
14
19
|
|
|
15
20
|
|
|
16
21
|
@jax.jit
|
|
@@ -51,7 +56,9 @@ def interp_and_integrate(energy_low, energy_high, energy_ref, continuum_ref, end
|
|
|
51
56
|
# Within
|
|
52
57
|
|
|
53
58
|
current_energy_is_between = (energy_low <= current_energy) * (current_energy < energy_high)
|
|
54
|
-
previous_energy_is_between = (energy_low <= previous_energy) * (
|
|
59
|
+
previous_energy_is_between = (energy_low <= previous_energy) * (
|
|
60
|
+
previous_energy < energy_high
|
|
61
|
+
)
|
|
55
62
|
energies_within_bins = (previous_energy <= energy_low) * (energy_high < current_energy)
|
|
56
63
|
|
|
57
64
|
case = (
|
|
@@ -68,7 +75,9 @@ def interp_and_integrate(energy_low, energy_high, energy_ref, continuum_ref, end
|
|
|
68
75
|
lambda pe, pc, ce, cc, el, er: (cc + lerp(el, pe, ce, pc, cc)) * (ce - el) / 2, # 2
|
|
69
76
|
lambda pe, pc, ce, cc, el, er: (pc + lerp(er, pe, ce, pc, cc)) * (er - pe) / 2, # 3
|
|
70
77
|
lambda pe, pc, ce, cc, el, er: (pc + cc) * (ce - pe) / 2, # 4
|
|
71
|
-
lambda pe, pc, ce, cc, el, er: (lerp(el, pe, ce, pc, cc) + lerp(er, pe, ce, pc, cc))
|
|
78
|
+
lambda pe, pc, ce, cc, el, er: (lerp(el, pe, ce, pc, cc) + lerp(er, pe, ce, pc, cc))
|
|
79
|
+
* (er - el)
|
|
80
|
+
/ 2,
|
|
72
81
|
# 5
|
|
73
82
|
],
|
|
74
83
|
previous_energy,
|
|
@@ -86,23 +95,24 @@ def interp_and_integrate(energy_low, energy_high, energy_ref, continuum_ref, end
|
|
|
86
95
|
return integrated_flux
|
|
87
96
|
|
|
88
97
|
|
|
98
|
+
@jax.jit
|
|
89
99
|
def interp(e_low, e_high, energy_ref, continuum_ref, end_index):
|
|
90
100
|
energy_ref = jnp.where(jnp.arange(energy_ref.shape[0]) < end_index, energy_ref, jnp.nan)
|
|
91
101
|
|
|
92
|
-
return (
|
|
102
|
+
return (
|
|
103
|
+
jnp.interp(e_high, energy_ref, continuum_ref) - jnp.interp(e_low, energy_ref, continuum_ref)
|
|
104
|
+
) / (e_high - e_low)
|
|
93
105
|
|
|
94
106
|
|
|
95
|
-
|
|
107
|
+
@jax.jit
|
|
108
|
+
def interp_flux(energy, energy_ref, continuum_ref, end_index):
|
|
96
109
|
"""
|
|
97
110
|
Iterate through an array of shape (energy_ref,) and compute the flux between the bins defined by energy
|
|
98
111
|
"""
|
|
99
112
|
|
|
100
113
|
def scanned_func(carry, unpack):
|
|
101
114
|
e_low, e_high = unpack
|
|
102
|
-
|
|
103
|
-
continuum = interp_and_integrate(e_low, e_high, energy_ref, continuum_ref, end_index)
|
|
104
|
-
else:
|
|
105
|
-
continuum = interp(e_low, e_high, energy_ref, continuum_ref, end_index)
|
|
115
|
+
continuum = interp_and_integrate(e_low, e_high, energy_ref, continuum_ref, end_index)
|
|
106
116
|
|
|
107
117
|
return carry, continuum
|
|
108
118
|
|
|
@@ -111,7 +121,8 @@ def interp_flux(energy, energy_ref, continuum_ref, end_index, integrate=True):
|
|
|
111
121
|
return continuum
|
|
112
122
|
|
|
113
123
|
|
|
114
|
-
|
|
124
|
+
@jax.jit
|
|
125
|
+
def interp_flux_elements(energy_ref, continuum_ref, end_index, energy, abundances):
|
|
115
126
|
"""
|
|
116
127
|
Iterate through an array of shape (abundance, energy_ref) and compute the flux between the bins defined by energy
|
|
117
128
|
and weight the flux depending on the abundance of each element
|
|
@@ -119,7 +130,7 @@ def interp_flux_elements(energy_ref, continuum_ref, end_index, energy, abundance
|
|
|
119
130
|
|
|
120
131
|
def scanned_func(_, unpack):
|
|
121
132
|
energy_ref, continuum_ref, end_idx = unpack
|
|
122
|
-
element_flux = interp_flux(energy, energy_ref, continuum_ref, end_idx
|
|
133
|
+
element_flux = interp_flux(energy, energy_ref, continuum_ref, end_idx)
|
|
123
134
|
|
|
124
135
|
return _, element_flux
|
|
125
136
|
|
|
@@ -136,23 +147,9 @@ def get_lines_contribution_broadening(
|
|
|
136
147
|
# Notice the -1 in line element to match the 0-based indexing
|
|
137
148
|
l_energy, l_emissivity, l_element = line_energy[i], line_emissivity[i], line_element[i] - 1
|
|
138
149
|
broadening = l_energy * total_broadening[l_element]
|
|
139
|
-
l_flux = gaussian.cdf(energy[1:], l_energy, broadening) - gaussian.cdf(
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
return flux + l_flux
|
|
143
|
-
|
|
144
|
-
return fori_loop(0, end_index, body_func, jnp.zeros_like(energy[:-1]))
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
@jax.jit
|
|
148
|
-
def get_lines_contribution_broadening_derivative(
|
|
149
|
-
line_energy, line_element, line_emissivity, end_index, energy, abundances, total_broadening
|
|
150
|
-
):
|
|
151
|
-
def body_func(i, flux):
|
|
152
|
-
# Notice the -1 in line element to match the 0-based indexing
|
|
153
|
-
l_energy, l_emissivity, l_element = line_energy[i], line_emissivity[i], line_element[i] - 1
|
|
154
|
-
broadening = l_energy * total_broadening[l_element]
|
|
155
|
-
l_flux = gaussian.cdf(energy[1:], l_energy, broadening) - gaussian.cdf(energy[:-1], l_energy, broadening)
|
|
150
|
+
l_flux = gaussian.cdf(energy[1:], l_energy, broadening) - gaussian.cdf(
|
|
151
|
+
energy[:-1], l_energy, broadening
|
|
152
|
+
)
|
|
156
153
|
l_flux = l_flux * l_emissivity * abundances[l_element]
|
|
157
154
|
|
|
158
155
|
return flux + l_flux
|
|
@@ -160,7 +157,6 @@ def get_lines_contribution_broadening_derivative(
|
|
|
160
157
|
return fori_loop(0, end_index, body_func, jnp.zeros_like(energy[:-1]))
|
|
161
158
|
|
|
162
159
|
|
|
163
|
-
@jax.custom_jvp
|
|
164
160
|
@jax.jit
|
|
165
161
|
def continuum_func(energy, kT, abundances):
|
|
166
162
|
idx, kT_low, kT_high = get_temperature(kT)
|
|
@@ -170,96 +166,27 @@ def continuum_func(energy, kT, abundances):
|
|
|
170
166
|
return lerp(kT, kT_low, kT_high, continuum_low, continuum_high)
|
|
171
167
|
|
|
172
168
|
|
|
173
|
-
@jax.jit
|
|
174
|
-
@continuum_func.defjvp
|
|
175
|
-
def continuum_jvp(primals, tangents):
|
|
176
|
-
energy, kT, abundances = primals
|
|
177
|
-
energy_dot, kT_dot, abundances_dot = tangents
|
|
178
|
-
|
|
179
|
-
idx, kT_low, kT_high = get_temperature(kT)
|
|
180
|
-
continuum_low_pars = get_continuum(idx)
|
|
181
|
-
continuum_high_pars = get_continuum(idx + 1)
|
|
182
|
-
|
|
183
|
-
# Energy derivative
|
|
184
|
-
dcontinuum_low_denerg = interp_flux_elements(*continuum_low_pars, energy, abundances, integrate=False)
|
|
185
|
-
dcontinuum_high_denerg = interp_flux_elements(*continuum_high_pars, energy, abundances, integrate=False)
|
|
186
|
-
energy_derivative = lerp(kT, kT_low, kT_high, dcontinuum_low_denerg, dcontinuum_high_denerg) * jnp.diff(energy_dot)
|
|
187
|
-
|
|
188
|
-
# Temperature derivative
|
|
189
|
-
continuum_low = interp_flux_elements(*continuum_low_pars, energy, abundances)
|
|
190
|
-
continuum_high = interp_flux_elements(*continuum_high_pars, energy, abundances)
|
|
191
|
-
kT_derivative = (continuum_high - continuum_low) / (kT_high - kT_low) * kT_dot
|
|
192
|
-
|
|
193
|
-
# Abundances derivative
|
|
194
|
-
dcontinuum_low_dabund = interp_flux_elements(*continuum_low_pars, energy, abundances_dot)
|
|
195
|
-
dcontinuum_high_dabund = interp_flux_elements(*continuum_high_pars, energy, abundances_dot)
|
|
196
|
-
abundances_derivative = lerp(kT, kT_low, kT_high, dcontinuum_low_dabund, dcontinuum_high_dabund)
|
|
197
|
-
|
|
198
|
-
primals_out = continuum_func(*primals)
|
|
199
|
-
|
|
200
|
-
return primals_out, energy_derivative + kT_derivative + abundances_derivative
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
@jax.custom_jvp
|
|
204
169
|
@jax.jit
|
|
205
170
|
def pseudo_func(energy, kT, abundances):
|
|
206
171
|
idx, kT_low, kT_high = get_temperature(kT)
|
|
207
|
-
continuum_low = interp_flux_elements(*
|
|
208
|
-
continuum_high = interp_flux_elements(*
|
|
172
|
+
continuum_low = interp_flux_elements(*get_pseudo(idx), energy, abundances)
|
|
173
|
+
continuum_high = interp_flux_elements(*get_pseudo(idx + 1), energy, abundances)
|
|
209
174
|
|
|
210
175
|
return lerp(kT, kT_low, kT_high, continuum_low, continuum_high)
|
|
211
176
|
|
|
212
177
|
|
|
213
|
-
@jax.
|
|
214
|
-
@pseudo_func.defjvp
|
|
215
|
-
def pseudo_jvp(primals, tangents):
|
|
216
|
-
energy, kT, abundances = primals
|
|
217
|
-
energy_dot, kT_dot, abundances_dot = tangents
|
|
218
|
-
|
|
219
|
-
idx, kT_low, kT_high = get_temperature(kT)
|
|
220
|
-
continuum_low_pars = get_pseudo(idx)
|
|
221
|
-
continuum_high_pars = get_pseudo(idx + 1)
|
|
222
|
-
|
|
223
|
-
# Energy derivative
|
|
224
|
-
dcontinuum_low_denerg = interp_flux_elements(*continuum_low_pars, energy, abundances, integrate=False)
|
|
225
|
-
dcontinuum_high_denerg = interp_flux_elements(*continuum_high_pars, energy, abundances, integrate=False)
|
|
226
|
-
energy_derivative = lerp(kT, kT_low, kT_high, dcontinuum_low_denerg, dcontinuum_high_denerg) * jnp.diff(energy_dot)
|
|
227
|
-
|
|
228
|
-
# Temperature derivative
|
|
229
|
-
continuum_low = interp_flux_elements(*continuum_low_pars, energy, abundances)
|
|
230
|
-
continuum_high = interp_flux_elements(*continuum_high_pars, energy, abundances)
|
|
231
|
-
kT_derivative = (continuum_high - continuum_low) / (kT_high - kT_low) * kT_dot
|
|
232
|
-
|
|
233
|
-
# Abundances derivative
|
|
234
|
-
dcontinuum_low_dabund = interp_flux_elements(*continuum_low_pars, energy, abundances_dot)
|
|
235
|
-
dcontinuum_high_dabund = interp_flux_elements(*continuum_high_pars, energy, abundances_dot)
|
|
236
|
-
abundances_derivative = lerp(kT, kT_low, kT_high, dcontinuum_low_dabund, dcontinuum_high_dabund)
|
|
237
|
-
|
|
238
|
-
primals_out = pseudo_func(*primals)
|
|
239
|
-
|
|
240
|
-
return primals_out, energy_derivative + kT_derivative + abundances_derivative
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
@jax.custom_jvp
|
|
178
|
+
# @jax.custom_jvp
|
|
244
179
|
@jax.jit
|
|
245
180
|
def lines_func(energy, kT, abundances, broadening):
|
|
246
181
|
idx, kT_low, kT_high = get_temperature(kT)
|
|
247
182
|
line_low = get_lines_contribution_broadening(*get_lines(idx), energy, abundances, broadening)
|
|
248
|
-
line_high = get_lines_contribution_broadening(
|
|
183
|
+
line_high = get_lines_contribution_broadening(
|
|
184
|
+
*get_lines(idx + 1), energy, abundances, broadening
|
|
185
|
+
)
|
|
249
186
|
|
|
250
187
|
return lerp(kT, kT_low, kT_high, line_low, line_high)
|
|
251
188
|
|
|
252
189
|
|
|
253
|
-
@jax.jit
|
|
254
|
-
@lines_func.defjvp
|
|
255
|
-
def lines_jvp(primals, tangents):
|
|
256
|
-
energy, kT, abundances, broadening = primals
|
|
257
|
-
energy_dot, kT_dot, abundances_dot, broadening_dot = tangents
|
|
258
|
-
|
|
259
|
-
primals_out = lines_func(*primals)
|
|
260
|
-
return primals_out, jnp.zeros_like(primals_out)
|
|
261
|
-
|
|
262
|
-
|
|
263
190
|
class APEC(AdditiveComponent):
|
|
264
191
|
"""
|
|
265
192
|
APEC model implementation in pure JAX for X-ray spectral fitting.
|
|
@@ -276,11 +203,15 @@ class APEC(AdditiveComponent):
|
|
|
276
203
|
thermal_broadening: bool = True,
|
|
277
204
|
turbulent_broadening: bool = True,
|
|
278
205
|
variant: Literal["none", "v", "vv"] = "none",
|
|
279
|
-
abundance_table: Literal[
|
|
206
|
+
abundance_table: Literal[
|
|
207
|
+
"angr", "aspl", "feld", "aneb", "grsa", "wilm", "lodd", "lgpp", "lgps"
|
|
208
|
+
] = "angr",
|
|
280
209
|
trace_abundance: float = 1.0,
|
|
281
210
|
**kwargs,
|
|
282
211
|
):
|
|
283
|
-
super(
|
|
212
|
+
super().__init__(**kwargs)
|
|
213
|
+
|
|
214
|
+
warnings.warn("Be aware that this APEC implementation is not meant to be used yet")
|
|
284
215
|
|
|
285
216
|
self.atomic_weights = jnp.asarray(element_data["atomic_weight"].to_numpy())
|
|
286
217
|
|
|
@@ -325,14 +256,18 @@ class APEC(AdditiveComponent):
|
|
|
325
256
|
"""
|
|
326
257
|
if self.turbulent_broadening:
|
|
327
258
|
# This return value must be multiplied by the energy of the line to get actual broadening
|
|
328
|
-
return
|
|
259
|
+
return (
|
|
260
|
+
hk.get_parameter("Velocity", [], init=HaikuConstant(100.0)) / c.to(u.km / u.s).value
|
|
261
|
+
)
|
|
329
262
|
else:
|
|
330
263
|
return 0.0
|
|
331
264
|
|
|
332
265
|
def get_parameters(self):
|
|
333
266
|
none_elements = ["C", "N", "O", "Ne", "Mg", "Al", "Si", "S", "Ar", "Ca", "Fe", "Ni"]
|
|
334
267
|
v_elements = ["He", "C", "N", "O", "Ne", "Mg", "Al", "Si", "S", "Ar", "Ca", "Fe", "Ni"]
|
|
335
|
-
trace_elements =
|
|
268
|
+
trace_elements = (
|
|
269
|
+
jnp.asarray([3, 4, 5, 9, 11, 15, 17, 19, 21, 22, 23, 24, 25, 27, 29, 30], dtype=int) - 1
|
|
270
|
+
)
|
|
336
271
|
|
|
337
272
|
# Set abundances of trace element (will be overwritten in the vv case)
|
|
338
273
|
abund = jnp.ones((30,)).at[trace_elements].multiply(self.trace_abundance)
|
|
@@ -354,7 +289,9 @@ class APEC(AdditiveComponent):
|
|
|
354
289
|
abund = abund.at[i].set(Z)
|
|
355
290
|
|
|
356
291
|
if abund != "angr":
|
|
357
|
-
abund = abund * jnp.asarray(
|
|
292
|
+
abund = abund * jnp.asarray(
|
|
293
|
+
abundance_table[self.abundance_table] / abundance_table["angr"]
|
|
294
|
+
)
|
|
358
295
|
|
|
359
296
|
# Set the temperature, redshift, normalisation
|
|
360
297
|
kT = hk.get_parameter("kT", [], init=HaikuConstant(6.5))
|
|
@@ -372,6 +309,8 @@ class APEC(AdditiveComponent):
|
|
|
372
309
|
|
|
373
310
|
continuum = continuum_func(energy, kT, abundances) if self.continuum_to_compute else 0.0
|
|
374
311
|
pseudo_continuum = pseudo_func(energy, kT, abundances) if self.pseudo_to_compute else 0.0
|
|
375
|
-
lines =
|
|
312
|
+
lines = (
|
|
313
|
+
lines_func(energy, kT, abundances, total_broadening) if self.lines_to_compute else 0.0
|
|
314
|
+
)
|
|
376
315
|
|
|
377
316
|
return (continuum + pseudo_continuum + lines) * norm * 1e14 / (1 + z), (e_low + e_high) / 2
|
|
@@ -1,90 +1,73 @@
|
|
|
1
|
-
"""
|
|
2
|
-
pure callback to enable reading data from the files without saturating the memory.
|
|
1
|
+
"""This module contains the functions that load the APEC tables from the HDF5 file. They are implemented as JAX
|
|
2
|
+
pure callback to enable reading data from the files without saturating the memory."""
|
|
3
3
|
|
|
4
|
+
import h5netcdf
|
|
4
5
|
import jax
|
|
5
6
|
import jax.numpy as jnp
|
|
6
|
-
import numpy as np
|
|
7
|
-
import importlib.resources
|
|
8
|
-
import xarray as xr
|
|
9
7
|
|
|
8
|
+
from ...util.online_storage import table_manager
|
|
10
9
|
|
|
11
|
-
apec_file = xr.open_dataset(importlib.resources.files("jaxspec") / "tables/apec.nc", engine="h5netcdf")
|
|
12
10
|
|
|
11
|
+
@jax.jit
|
|
12
|
+
def temperature_table_getter():
|
|
13
|
+
with h5netcdf.File(table_manager.fetch("apec.nc"), "r") as f:
|
|
14
|
+
temperature = jnp.asarray(f["/temperature"])
|
|
13
15
|
|
|
14
|
-
|
|
15
|
-
idx = np.searchsorted(apec_file.temperature, kT) - 1
|
|
16
|
-
|
|
17
|
-
if idx > len(apec_file.temperature) - 2:
|
|
18
|
-
return idx, 0.0, 0.0
|
|
19
|
-
else:
|
|
20
|
-
return idx, float(apec_file.temperature[idx]), float(apec_file.temperature[idx + 1])
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
def continuum_table_getter(idx):
|
|
24
|
-
if idx > len(apec_file.temperature) - 2:
|
|
25
|
-
continuum_energy_array = jnp.zeros_like(apec_file.continuum_energy[0])
|
|
26
|
-
continuum_emissivity_array = jnp.zeros_like(apec_file.continuum_emissivity[0])
|
|
27
|
-
end_index_continuum = jnp.zeros_like(apec_file.continuum_end_index[0])
|
|
28
|
-
|
|
29
|
-
else:
|
|
30
|
-
continuum_energy_array = jnp.asarray(apec_file.continuum_energy[idx])
|
|
31
|
-
continuum_emissivity_array = jnp.asarray(apec_file.continuum_emissivity[idx])
|
|
32
|
-
end_index_continuum = jnp.asarray(apec_file.continuum_end_index[idx])
|
|
33
|
-
|
|
34
|
-
return continuum_energy_array, continuum_emissivity_array, end_index_continuum
|
|
35
|
-
|
|
16
|
+
return temperature
|
|
36
17
|
|
|
37
|
-
def pseudo_table_getter(idx):
|
|
38
|
-
if idx > len(apec_file.temperature) - 2:
|
|
39
|
-
pseudo_energy_array = jnp.zeros_like(apec_file.pseudo_energy[0])
|
|
40
|
-
pseudo_emissivity_array = jnp.zeros_like(apec_file.pseudo_emissivity[0])
|
|
41
|
-
end_index_pseudo = jnp.zeros_like(apec_file.pseudo_end_index[0])
|
|
42
18
|
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
19
|
+
@jax.jit
|
|
20
|
+
def get_temperature(kT):
|
|
21
|
+
temperature = temperature_table_getter()
|
|
22
|
+
idx = jnp.searchsorted(temperature, kT) - 1
|
|
47
23
|
|
|
48
|
-
return
|
|
24
|
+
return idx, temperature[idx], temperature[idx + 1]
|
|
49
25
|
|
|
50
26
|
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
27
|
+
@jax.jit
|
|
28
|
+
def continuum_table_getter():
|
|
29
|
+
with h5netcdf.File(table_manager.fetch("apec.nc"), "r") as f:
|
|
30
|
+
continuum_energy = jnp.asarray(f["/continuum_energy"])
|
|
31
|
+
continuum_emissivity = jnp.asarray(f["/continuum_emissivity"])
|
|
32
|
+
continuum_end_index = jnp.asarray(f["/continuum_end_index"])
|
|
57
33
|
|
|
58
|
-
|
|
59
|
-
line_energy_array = jnp.asarray(apec_file.line_energy[idx])
|
|
60
|
-
line_element_array = jnp.asarray(apec_file.line_element[idx])
|
|
61
|
-
line_emissivity_array = jnp.asarray(apec_file.line_emissivity[idx])
|
|
62
|
-
end_index_lines = jnp.asarray(apec_file.line_end_index[idx])
|
|
34
|
+
return continuum_energy, continuum_emissivity, continuum_end_index
|
|
63
35
|
|
|
64
|
-
return line_energy_array, line_element_array, line_emissivity_array, end_index_lines
|
|
65
36
|
|
|
37
|
+
@jax.jit
|
|
38
|
+
def pseudo_table_getter():
|
|
39
|
+
with h5netcdf.File(table_manager.fetch("apec.nc"), "r") as f:
|
|
40
|
+
pseudo_energy = jnp.asarray(f["/pseudo_energy"])
|
|
41
|
+
pseudo_emissivity = jnp.asarray(f["/pseudo_emissivity"])
|
|
42
|
+
pseudo_end_index = jnp.asarray(f["/pseudo_end_index"])
|
|
66
43
|
|
|
67
|
-
|
|
68
|
-
pure_callback_continuum_shape = jax.eval_shape(lambda: jax.tree.map(jnp.asarray, continuum_table_getter(0)))
|
|
69
|
-
pure_callback_pseudo_shape = jax.eval_shape(lambda: jax.tree.map(jnp.asarray, pseudo_table_getter(0)))
|
|
70
|
-
pure_callback_line_shape = jax.eval_shape(lambda: jax.tree.map(jnp.asarray, lines_table_getter(0)))
|
|
44
|
+
return pseudo_energy, pseudo_emissivity, pseudo_end_index
|
|
71
45
|
|
|
72
46
|
|
|
73
47
|
@jax.jit
|
|
74
|
-
def
|
|
75
|
-
|
|
48
|
+
def line_table_getter():
|
|
49
|
+
with h5netcdf.File(table_manager.fetch("apec.nc"), "r") as f:
|
|
50
|
+
line_energy = jnp.asarray(f["/line_energy"])
|
|
51
|
+
line_element = jnp.asarray(f["/line_element"])
|
|
52
|
+
line_emissivity = jnp.asarray(f["/line_emissivity"])
|
|
53
|
+
line_end_index = jnp.asarray(f["/line_end_index"])
|
|
54
|
+
|
|
55
|
+
return line_energy, line_element, line_emissivity, line_end_index
|
|
76
56
|
|
|
77
57
|
|
|
78
58
|
@jax.jit
|
|
79
59
|
def get_continuum(idx):
|
|
80
|
-
|
|
60
|
+
continuum_energy, continuum_emissivity, continuum_end_index = continuum_table_getter()
|
|
61
|
+
return continuum_energy[idx], continuum_emissivity[idx], continuum_end_index[idx]
|
|
81
62
|
|
|
82
63
|
|
|
83
64
|
@jax.jit
|
|
84
65
|
def get_pseudo(idx):
|
|
85
|
-
|
|
66
|
+
pseudo_energy, pseudo_emissivity, pseudo_end_index = pseudo_table_getter()
|
|
67
|
+
return pseudo_energy[idx], pseudo_emissivity[idx], pseudo_end_index[idx]
|
|
86
68
|
|
|
87
69
|
|
|
88
70
|
@jax.jit
|
|
89
71
|
def get_lines(idx):
|
|
90
|
-
|
|
72
|
+
line_energy, line_element, line_emissivity, line_end_index = line_table_getter()
|
|
73
|
+
return line_energy[idx], line_element[idx], line_emissivity[idx], line_end_index[idx]
|
jaxspec/model/additive.py
CHANGED
|
@@ -1,16 +1,18 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
|
|
3
|
+
import astropy.constants
|
|
4
|
+
import astropy.units as u
|
|
4
5
|
import haiku as hk
|
|
5
6
|
import jax
|
|
6
7
|
import jax.numpy as jnp
|
|
7
8
|
import jax.scipy as jsp
|
|
8
|
-
|
|
9
|
-
import astropy.constants
|
|
9
|
+
|
|
10
10
|
from haiku.initializers import Constant as HaikuConstant
|
|
11
|
+
|
|
11
12
|
from ..util.integrate import integrate_interval
|
|
13
|
+
|
|
14
|
+
# from ._additive.apec import APEC
|
|
12
15
|
from .abc import AdditiveComponent
|
|
13
|
-
from ._additive.apec import APEC # noqa: F401
|
|
14
16
|
|
|
15
17
|
|
|
16
18
|
class Powerlaw(AdditiveComponent):
|
|
@@ -283,16 +285,16 @@ class Diskpbb(AdditiveComponent):
|
|
|
283
285
|
where $$p$$ is a free parameter. The standard disk model, diskbb, is recovered if $$p=0.75$$.
|
|
284
286
|
If radial advection is important then $$p<0.75$$.
|
|
285
287
|
|
|
286
|
-
|
|
287
|
-
|
|
288
|
+
$$\\mathcal{M}\\left( E \right) = \frac{2\\pi(\\cos i)r^{2}_{\text{in}}}{pd^2} \\int_{T_{\text{in}}}^{T_{\text{out}}}
|
|
289
|
+
\\left( \frac{T}{T_{\text{in}}} \right)^{-2/p-1} \text{bbody}(E,T) \frac{dT}{T_{\text{in}}}$$
|
|
288
290
|
|
|
289
291
|
??? abstract "Parameters"
|
|
290
|
-
* $\text{norm}$ :
|
|
291
|
-
where $r_{\text{in}}$ is "an apparent" inner disk radius
|
|
292
|
+
* $\text{norm}$ : $\\cos i(r_{\text{in}}/d)^{2}$,
|
|
293
|
+
where $r_{\text{in}}$ is "an apparent" inner disk radius $\\left[\text{km}\right]$,
|
|
292
294
|
$d$ the distance to the source in units of $10 \text{kpc}$,
|
|
293
295
|
$i$ the angle of the disk ($i=0$ is face-on)
|
|
294
|
-
* $p$ : Exponent of the radial dependence of the disk temperature
|
|
295
|
-
* $T_{\text{in}}$ : Temperature at inner disk radius
|
|
296
|
+
* $p$ : Exponent of the radial dependence of the disk temperature $\\left[\text{dimensionless}\right]$
|
|
297
|
+
* $T_{\text{in}}$ : Temperature at inner disk radius $\\left[ \\mathrm{keV}\right]$
|
|
296
298
|
"""
|
|
297
299
|
|
|
298
300
|
def continuum(self, energy):
|
|
@@ -332,7 +334,13 @@ class Diskbb(AdditiveComponent):
|
|
|
332
334
|
return e**2 * (kT / tin) ** (-2 / p - 1) / (jnp.exp(e / kT) - 1)
|
|
333
335
|
|
|
334
336
|
integral = integrate_interval(integrand)
|
|
335
|
-
return
|
|
337
|
+
return (
|
|
338
|
+
norm
|
|
339
|
+
* 2.78e-3
|
|
340
|
+
* (0.75 / p)
|
|
341
|
+
/ tin
|
|
342
|
+
* jnp.vectorize(lambda e: integral(tout, tin, e, tin, p))(energy)
|
|
343
|
+
)
|
|
336
344
|
|
|
337
345
|
|
|
338
346
|
class Agauss(AdditiveComponent):
|
|
@@ -381,7 +389,11 @@ class Zagauss(AdditiveComponent):
|
|
|
381
389
|
norm = hk.get_parameter("norm", [], init=HaikuConstant(1))
|
|
382
390
|
redshift = hk.get_parameter("redshift", [], init=HaikuConstant(0))
|
|
383
391
|
|
|
384
|
-
return
|
|
392
|
+
return (
|
|
393
|
+
norm
|
|
394
|
+
* (1 + redshift)
|
|
395
|
+
* jsp.stats.norm.pdf((hc / energy) / (1 + redshift), loc=line_wavelength, scale=sigma)
|
|
396
|
+
)
|
|
385
397
|
|
|
386
398
|
|
|
387
399
|
class Zgauss(AdditiveComponent):
|
|
@@ -404,4 +416,6 @@ class Zgauss(AdditiveComponent):
|
|
|
404
416
|
norm = hk.get_parameter("norm", [], init=HaikuConstant(1))
|
|
405
417
|
redshift = hk.get_parameter("redshift", [], init=HaikuConstant(0))
|
|
406
418
|
|
|
407
|
-
return (norm / (1 + redshift)) * jsp.stats.norm.pdf(
|
|
419
|
+
return (norm / (1 + redshift)) * jsp.stats.norm.pdf(
|
|
420
|
+
energy * (1 + redshift), loc=line_energy, scale=sigma
|
|
421
|
+
)
|
jaxspec/model/background.py
CHANGED
|
@@ -1,9 +1,11 @@
|
|
|
1
1
|
from abc import ABC, abstractmethod
|
|
2
|
-
|
|
3
|
-
from jax.scipy.integrate import trapezoid
|
|
4
|
-
import numpyro.distributions as dist
|
|
2
|
+
|
|
5
3
|
import jax.numpy as jnp
|
|
6
4
|
import numpyro
|
|
5
|
+
import numpyro.distributions as dist
|
|
6
|
+
|
|
7
|
+
from jax.scipy.integrate import trapezoid
|
|
8
|
+
from tinygp import GaussianProcess, kernels
|
|
7
9
|
|
|
8
10
|
|
|
9
11
|
class BackgroundModel(ABC):
|
|
@@ -33,7 +35,8 @@ class SubtractedBackground(BackgroundModel):
|
|
|
33
35
|
|
|
34
36
|
"""
|
|
35
37
|
|
|
36
|
-
def numpyro_model(self,
|
|
38
|
+
def numpyro_model(self, obs, spectral_model, name: str = "bkg", observed=True):
|
|
39
|
+
_, observed_counts = obs.out_energies, obs.folded_background.data
|
|
37
40
|
numpyro.deterministic(f"{name}", observed_counts)
|
|
38
41
|
|
|
39
42
|
return jnp.zeros_like(observed_counts)
|
|
@@ -49,32 +52,37 @@ class BackgroundWithError(BackgroundModel):
|
|
|
49
52
|
but slower since it performs the fit using MCMC instead of analytical solution.
|
|
50
53
|
"""
|
|
51
54
|
|
|
52
|
-
def numpyro_model(self,
|
|
55
|
+
def numpyro_model(self, obs, spectral_model, name: str = "bkg", observed=True):
|
|
53
56
|
# Gamma in numpyro is parameterized by concentration and rate (alpha/beta)
|
|
57
|
+
_, observed_counts = obs.out_energies, obs.folded_background.data
|
|
54
58
|
alpha = observed_counts + 1
|
|
55
59
|
beta = 1
|
|
56
60
|
countrate = numpyro.sample(f"{name}_params", dist.Gamma(alpha, rate=beta))
|
|
57
61
|
|
|
58
62
|
with numpyro.plate(f"{name}_plate", len(observed_counts)):
|
|
59
|
-
numpyro.sample(
|
|
63
|
+
numpyro.sample(
|
|
64
|
+
f"{name}", dist.Poisson(countrate), obs=observed_counts if observed else None
|
|
65
|
+
)
|
|
60
66
|
|
|
61
67
|
return countrate
|
|
62
68
|
|
|
63
69
|
|
|
64
70
|
'''
|
|
71
|
+
# TODO: Implement this class and sample it with Gibbs Sampling
|
|
72
|
+
|
|
65
73
|
class ConjugateBackground(BackgroundModel):
|
|
66
74
|
r"""
|
|
67
|
-
This class fit an expected rate
|
|
75
|
+
This class fit an expected rate $\\lambda$ in each bin of the background spectrum. Assuming a Gamma prior
|
|
68
76
|
distribution, we can analytically derive the posterior as a Negative binomial distribution.
|
|
69
77
|
|
|
70
|
-
$$ p(
|
|
71
|
-
p
|
|
78
|
+
$$ p(\\lambda_{\text{Bkg}}) \\sim \\Gamma \\left( \alpha, \beta \right) \\implies
|
|
79
|
+
p\\left(\\lambda_{\text{Bkg}} | \text{Counts}_{\text{Bkg}}\right) \\sim \text{NB}\\left(\alpha, \frac{\beta}{\beta +1}
|
|
72
80
|
\right) $$
|
|
73
81
|
|
|
74
82
|
!!! info
|
|
75
83
|
Here, $\alpha$ and $\beta$ are set to $\alpha = \text{Counts}_{\text{Bkg}} + 1$ and $\beta = 1$. Doing so,
|
|
76
|
-
the prior distribution is such that
|
|
77
|
-
$\text{Var}[
|
|
84
|
+
the prior distribution is such that $\\mathbb{E}[\\lambda_{\text{Bkg}}] = \text{Counts}_{\text{Bkg}} +1$ and
|
|
85
|
+
$\text{Var}[\\lambda_{\text{Bkg}}] = \text{Counts}_{\text{Bkg}}+1$. The +1 is to avoid numerical issues when the
|
|
78
86
|
counts are 0, and add a small scatter even if the measured background is effectively null.
|
|
79
87
|
|
|
80
88
|
??? abstract "References"
|
|
@@ -97,6 +105,19 @@ class ConjugateBackground(BackgroundModel):
|
|
|
97
105
|
return countrate
|
|
98
106
|
'''
|
|
99
107
|
|
|
108
|
+
"""
|
|
109
|
+
class SpectralBackgroundModel(BackgroundModel):
|
|
110
|
+
# I should pass the current spectral model as an argument to the background model
|
|
111
|
+
# In the numpyro model function
|
|
112
|
+
def __init__(self, model, prior):
|
|
113
|
+
self.model = model
|
|
114
|
+
self.prior = prior
|
|
115
|
+
|
|
116
|
+
def numpyro_model(self, energy, observed_counts, name: str = "bkg", observed=True):
|
|
117
|
+
#TODO : keep the sparsification from top model
|
|
118
|
+
transformed_model = hk.without_apply_rng(hk.transform(lambda par: CountForwardModel(model, obs, sparse=False)(par)))
|
|
119
|
+
"""
|
|
120
|
+
|
|
100
121
|
|
|
101
122
|
class GaussianProcessBackground(BackgroundModel):
|
|
102
123
|
"""
|
|
@@ -104,7 +125,13 @@ class GaussianProcessBackground(BackgroundModel):
|
|
|
104
125
|
[`tinygp`](https://tinygp.readthedocs.io/en/stable/guide.html) library.
|
|
105
126
|
"""
|
|
106
127
|
|
|
107
|
-
def __init__(
|
|
128
|
+
def __init__(
|
|
129
|
+
self,
|
|
130
|
+
e_min: float,
|
|
131
|
+
e_max: float,
|
|
132
|
+
n_nodes: int = 30,
|
|
133
|
+
kernel: kernels.Kernel = kernels.Matern52,
|
|
134
|
+
):
|
|
108
135
|
"""
|
|
109
136
|
Build the Gaussian Process background model.
|
|
110
137
|
|
|
@@ -119,7 +146,7 @@ class GaussianProcessBackground(BackgroundModel):
|
|
|
119
146
|
self.n_nodes = n_nodes
|
|
120
147
|
self.kernel = kernel
|
|
121
148
|
|
|
122
|
-
def numpyro_model(self,
|
|
149
|
+
def numpyro_model(self, obs, spectral_model, name: str = "bkg", observed=True):
|
|
123
150
|
"""
|
|
124
151
|
Build the model for the background.
|
|
125
152
|
|
|
@@ -129,9 +156,12 @@ class GaussianProcessBackground(BackgroundModel):
|
|
|
129
156
|
name: The name of the background model for parameters disambiguation.
|
|
130
157
|
observed: Whether the model is observed or not. Useful for `numpyro.infer.Predictive` calls.
|
|
131
158
|
"""
|
|
159
|
+
energy, observed_counts = obs.out_energies, obs.folded_background.data
|
|
132
160
|
|
|
133
161
|
if (observed_counts is not None) and (self.n_nodes >= len(observed_counts)):
|
|
134
|
-
raise RuntimeError(
|
|
162
|
+
raise RuntimeError(
|
|
163
|
+
"More nodes than channels in the observation associated with GaussianProcessBackground."
|
|
164
|
+
)
|
|
135
165
|
|
|
136
166
|
# The parameters of the GP model
|
|
137
167
|
mean = numpyro.sample(f"{name}_mean", dist.Normal(jnp.log(jnp.mean(observed_counts)), 2.0))
|
|
@@ -144,13 +174,17 @@ class GaussianProcessBackground(BackgroundModel):
|
|
|
144
174
|
gp = GaussianProcess(kernel, nodes, diag=1e-5 * jnp.ones_like(nodes), mean=mean)
|
|
145
175
|
|
|
146
176
|
log_rate = numpyro.sample(f"_{name}_log_rate_nodes", gp.numpyro_dist())
|
|
147
|
-
interp_count_rate = jnp.exp(
|
|
177
|
+
interp_count_rate = jnp.exp(
|
|
178
|
+
jnp.interp(energy, nodes * (self.e_max - self.e_min) + self.e_min, log_rate)
|
|
179
|
+
)
|
|
148
180
|
count_rate = trapezoid(interp_count_rate, energy, axis=0)
|
|
149
181
|
|
|
150
182
|
# Finally, our observation model is Poisson
|
|
151
183
|
with numpyro.plate(f"{name}_plate", len(observed_counts)):
|
|
152
184
|
# TODO : change to Poisson Likelihood when there is no background model
|
|
153
185
|
# TODO : Otherwise clip the background model to 1e-6 to avoid numerical issues
|
|
154
|
-
numpyro.sample(
|
|
186
|
+
numpyro.sample(
|
|
187
|
+
f"{name}", dist.Poisson(count_rate), obs=observed_counts if observed else None
|
|
188
|
+
)
|
|
155
189
|
|
|
156
190
|
return count_rate
|