jaxspec 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxspec/__init__.py +1 -1
- jaxspec/analysis/compare.py +3 -3
- jaxspec/analysis/results.py +239 -110
- jaxspec/data/instrument.py +0 -2
- jaxspec/data/ogip.py +18 -0
- jaxspec/data/util.py +11 -3
- jaxspec/fit.py +166 -72
- jaxspec/model/_additive/__init__.py +0 -0
- jaxspec/model/_additive/apec.py +377 -0
- jaxspec/model/_additive/apec_loaders.py +90 -0
- jaxspec/model/abc.py +55 -7
- jaxspec/model/additive.py +2 -51
- jaxspec/tables/abundances.dat +31 -0
- jaxspec/util/abundance.py +111 -0
- jaxspec/util/integrate.py +5 -4
- {jaxspec-0.0.4.dist-info → jaxspec-0.0.6.dist-info}/METADATA +5 -3
- {jaxspec-0.0.4.dist-info → jaxspec-0.0.6.dist-info}/RECORD +19 -14
- {jaxspec-0.0.4.dist-info → jaxspec-0.0.6.dist-info}/LICENSE.md +0 -0
- {jaxspec-0.0.4.dist-info → jaxspec-0.0.6.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,377 @@
|
|
|
1
|
+
import jax.numpy as jnp
|
|
2
|
+
import jax
|
|
3
|
+
import haiku as hk
|
|
4
|
+
import astropy.units as u
|
|
5
|
+
from jax import lax
|
|
6
|
+
from jax.lax import scan, fori_loop
|
|
7
|
+
from jax.scipy.stats import norm as gaussian
|
|
8
|
+
from typing import Literal
|
|
9
|
+
from ...util.abundance import abundance_table, element_data
|
|
10
|
+
from haiku.initializers import Constant as HaikuConstant
|
|
11
|
+
from astropy.constants import c, m_p
|
|
12
|
+
from ..abc import AdditiveComponent
|
|
13
|
+
from .apec_loaders import get_temperature, get_continuum, get_pseudo, get_lines
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@jax.jit
def lerp(x, x0, x1, y0, y1):
    """
    Linearly interpolate between the points (x0, y0) and (x1, y1), evaluated at x.

    Computes ``y(x) = (y0 * (x1 - x) + y1 * (x - x0)) / (x1 - x0)``. Arguments may be
    scalars or broadcast-compatible arrays. No guard is applied when ``x1 == x0``; the
    division is then by zero.
    """
    weight_low = x1 - x
    weight_high = x - x0
    span = x1 - x0
    return (y0 * weight_low + y1 * weight_high) / span
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@jax.jit
def interp_and_integrate(energy_low, energy_high, energy_ref, continuum_ref, end_index):
    """
    Integrate a tabulated reference continuum between two energy limits.

    The integral uses the trapezoidal rule on the tabulated points. Where an integration
    limit falls between two tabulated energies, the continuum is linearly interpolated
    (``lerp``) at that limit.

    Parameters:
        energy_low: lower limit of the integral
        energy_high: upper limit of the integral
        energy_ref: energy grid of the reference continuum (only the first ``end_index``
            entries are used; the rest are masked with NaN)
        continuum_ref: continuum values evaluated at energy_ref
        end_index: number of valid entries in energy_ref / continuum_ref

    Returns:
        The integrated continuum between energy_low and energy_high.
    """
    # Entries past end_index are padding; NaN keeps them out of the searchsorted results below.
    energy_ref = jnp.where(jnp.arange(energy_ref.shape[0]) < end_index, energy_ref, jnp.nan)
    # Loop over the tabulated points bracketing [energy_low, energy_high], one extra on each side.
    # NOTE(review): start_index is -1 when energy_low <= energy_ref[0]; indexing then wraps to
    # the last (possibly NaN-padded) entry — confirm this is harmless for the tables used.
    start_index = jnp.searchsorted(energy_ref, energy_low, side="left") - 1
    end_index = jnp.searchsorted(energy_ref, energy_high, side="left") + 1

    def body_func(index, value):
        # Carry: running integral and the (energy, continuum) point of the previous iteration.
        integrated_flux, previous_energy, previous_continuum = value
        current_energy, current_continuum = energy_ref[index], continuum_ref[index]

        # 5 cases (switch index in parentheses), for the segment [previous_energy, current_energy]:
        # (0) Neither current nor previous energy is within the integral limits > nothing is added
        # (1) The left limit lies between previous and current energy > the segment [el, ce] is added,
        #     with the continuum at el interpolated
        # (2) The right limit lies between previous and current energy > the segment [pe, er] is added,
        #     with the continuum at er interpolated
        # (3) Both current and previous energies are within the limits > the full segment is added
        # (4) Both limits lie inside this single segment > [el, er] is added, both ends interpolated

        current_energy_is_between = (energy_low <= current_energy) * (current_energy < energy_high)
        previous_energy_is_between = (energy_low <= previous_energy) * (previous_energy < energy_high)
        energies_within_bins = (previous_energy <= energy_low) * (energy_high < current_energy)

        # NOTE: when previous_energy == energy_low, both the "2" and "4" terms fire (index 6).
        # lax.switch clamps out-of-range indices to the last branch, which is the correct one there.
        case = (
            (1 - previous_energy_is_between) * current_energy_is_between * 1
            + previous_energy_is_between * (1 - current_energy_is_between) * 2
            + (previous_energy_is_between * current_energy_is_between) * 3
            + energies_within_bins * 4
        )

        # Branch arguments: pe/pc previous energy/continuum, ce/cc current energy/continuum,
        # el/er lower/upper integration limits.
        term_to_add = lax.switch(
            case,
            [
                lambda pe, pc, ce, cc, el, er: 0.0,  # case 0
                lambda pe, pc, ce, cc, el, er: (cc + lerp(el, pe, ce, pc, cc)) * (ce - el) / 2,  # case 1
                lambda pe, pc, ce, cc, el, er: (pc + lerp(er, pe, ce, pc, cc)) * (er - pe) / 2,  # case 2
                lambda pe, pc, ce, cc, el, er: (pc + cc) * (ce - pe) / 2,  # case 3
                lambda pe, pc, ce, cc, el, er: (lerp(el, pe, ce, pc, cc) + lerp(er, pe, ce, pc, cc)) * (er - el) / 2,
                # case 4
            ],
            previous_energy,
            previous_continuum,
            current_energy,
            current_continuum,
            energy_low,
            energy_high,
        )

        return integrated_flux + term_to_add, current_energy, current_continuum

    # The carry starts from a (0, 0) "previous point". The first iteration (at start_index)
    # presumably lies below energy_low, so it only seeds the previous point — assumes
    # positive energies, TODO confirm.
    integrated_flux, _, _ = fori_loop(start_index, end_index, body_func, (0.0, 0.0, 0.0))

    return integrated_flux
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def interp(e_low, e_high, energy_ref, continuum_ref, end_index):
    """
    Finite-difference slope of the tabulated continuum between e_low and e_high.

    Both limits are evaluated with ``jnp.interp`` after masking the entries of
    ``energy_ref`` at positions >= ``end_index`` with NaN. The result is
    ``(f(e_high) - f(e_low)) / (e_high - e_low)``.
    """
    is_valid = jnp.arange(energy_ref.shape[0]) < end_index
    masked_energy = jnp.where(is_valid, energy_ref, jnp.nan)

    value_high = jnp.interp(e_high, masked_energy, continuum_ref)
    value_low = jnp.interp(e_low, masked_energy, continuum_ref)
    return (value_high - value_low) / (e_high - e_low)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def interp_flux(energy, energy_ref, continuum_ref, end_index, integrate=True):
    """
    Compute one value per energy bin from a single tabulated continuum.

    The bins are the consecutive pairs ``(energy[i], energy[i + 1])``.

    If ``integrate`` is true, each value is the continuum integrated over the bin
    (``interp_and_integrate``). Otherwise it is the finite-difference slope across the
    bin (``interp``). ``integrate`` is a Python flag resolved at trace time.
    """
    per_bin = interp_and_integrate if integrate else interp

    def bin_step(carry, bounds):
        lower, upper = bounds
        return carry, per_bin(lower, upper, energy_ref, continuum_ref, end_index)

    _, binned = scan(bin_step, 0.0, (energy[:-1], energy[1:]))
    return binned
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def interp_flux_elements(energy_ref, continuum_ref, end_index, energy, abundances, integrate=True):
    """
    Compute per-bin fluxes for every element, then combine them weighted by abundance.

    ``energy_ref``, ``continuum_ref`` and ``end_index`` are scanned along their leading
    (element) axis. Each element's table goes through ``interp_flux`` on the bins defined
    by ``energy``. The stacked per-element fluxes are contracted with ``abundances``
    (``abundances @ flux``).
    """

    def element_step(carry, element_table):
        element_energy, element_continuum, element_end = element_table
        element_flux = interp_flux(energy, element_energy, element_continuum, element_end, integrate=integrate)
        return carry, element_flux

    _, flux_per_element = scan(element_step, 0.0, (energy_ref, continuum_ref, end_index))

    # Abundance-weighted sum over the element axis
    return abundances @ flux_per_element
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
@jax.jit
def get_lines_contribution_broadening(
    line_energy, line_element, line_emissivity, end_index, energy, abundances, total_broadening
):
    """
    Accumulate Gaussian-broadened emission lines into the bins defined by ``energy``.

    Each of the first ``end_index`` lines contributes a Gaussian centred on
    ``line_energy[i]``, with standard deviation ``line_energy[i] * total_broadening[element]``.
    It is integrated over each bin as a difference of normal CDFs and scaled by
    ``line_emissivity[i] * abundances[element]``.

    Returns an array with one value per bin (``len(energy) - 1``).
    """
    lower_edges, upper_edges = energy[:-1], energy[1:]

    def add_line(i, accumulated):
        centre = line_energy[i]
        # line_element values are 1-based; shift to 0-based array indexing
        element = line_element[i] - 1
        sigma = centre * total_broadening[element]
        profile = gaussian.cdf(upper_edges, centre, sigma) - gaussian.cdf(lower_edges, centre, sigma)
        return accumulated + profile * line_emissivity[i] * abundances[element]

    return fori_loop(0, end_index, add_line, jnp.zeros_like(lower_edges))
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
@jax.jit
def get_lines_contribution_broadening_derivative(
    line_energy, line_element, line_emissivity, end_index, energy, abundances, total_broadening
):
    """
    Return exactly what ``get_lines_contribution_broadening`` returns for the same arguments.

    This function delegates instead of duplicating that loop, so the two cannot drift apart.

    NOTE(review): despite its name, nothing here differentiates the line flux with respect
    to any argument. Confirm whether a derivative was intended.
    """
    return get_lines_contribution_broadening(
        line_energy, line_element, line_emissivity, end_index, energy, abundances, total_broadening
    )
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
@jax.custom_jvp
@jax.jit
def continuum_func(energy, kT, abundances):
    """
    Continuum flux per energy bin, linearly interpolated in temperature.

    The continuum tables of the two tabulated temperatures bracketing ``kT`` are
    integrated over the bins of ``energy`` and weighted by ``abundances``. The results
    are then interpolated with ``lerp`` at ``kT``. The derivative rule is ``continuum_jvp``.
    """
    idx, kT_low, kT_high = get_temperature(kT)
    bracket_fluxes = [interp_flux_elements(*get_continuum(i), energy, abundances) for i in (idx, idx + 1)]
    return lerp(kT, kT_low, kT_high, *bracket_fluxes)
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
@jax.jit
@continuum_func.defjvp
def continuum_jvp(primals, tangents):
    """
    Custom JVP rule for ``continuum_func``.

    The output tangent is the sum of three terms:
      * energy: the temperature-interpolated finite-difference slope of the continuum
        across each bin (``integrate=False``), times ``jnp.diff(energy_dot)``;
      * temperature: slope of the linear temperature interpolation, times ``kT_dot``;
      * abundances: the continuum recomputed with ``abundances_dot`` as weights (the
        flux is linear in the abundances).

    NOTE(review): the exact derivative of a bin integral with respect to its edges would be
    ``f(e_high) * de_high - f(e_low) * de_low``. The energy term used here is a different
    approximation — confirm it is intended.
    """
    energy, kT, abundances = primals
    energy_dot, kT_dot, abundances_dot = tangents

    idx, kT_low, kT_high = get_temperature(kT)
    low_table, high_table = get_continuum(idx), get_continuum(idx + 1)

    def at_brackets(weights, integrate=True):
        # Element-weighted flux for the two bracketing temperature tables
        low = interp_flux_elements(*low_table, energy, weights, integrate=integrate)
        high = interp_flux_elements(*high_table, energy, weights, integrate=integrate)
        return low, high

    slope_low, slope_high = at_brackets(abundances, integrate=False)
    energy_term = lerp(kT, kT_low, kT_high, slope_low, slope_high) * jnp.diff(energy_dot)

    flux_low, flux_high = at_brackets(abundances)
    kT_term = (flux_high - flux_low) / (kT_high - kT_low) * kT_dot

    weighted_low, weighted_high = at_brackets(abundances_dot)
    abundance_term = lerp(kT, kT_low, kT_high, weighted_low, weighted_high)

    return continuum_func(*primals), energy_term + kT_term + abundance_term
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
@jax.custom_jvp
@jax.jit
def pseudo_func(energy, kT, abundances):
    """
    Pseudo-continuum flux per energy bin, linearly interpolated in temperature.

    The pseudo-continuum tables (``get_pseudo``) of the two tabulated temperatures
    bracketing ``kT`` are integrated over the bins of ``energy`` and weighted by
    ``abundances``. The results are then interpolated with ``lerp`` at ``kT``. The
    derivative rule is ``pseudo_jvp``, which reads the same tables.
    """
    idx, kT_low, kT_high = get_temperature(kT)
    # Must read the pseudo-continuum tables: reading the continuum tables here would
    # double-count the continuum and disagree with pseudo_jvp.
    pseudo_low = interp_flux_elements(*get_pseudo(idx), energy, abundances)
    pseudo_high = interp_flux_elements(*get_pseudo(idx + 1), energy, abundances)

    return lerp(kT, kT_low, kT_high, pseudo_low, pseudo_high)
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
@jax.jit
@pseudo_func.defjvp
def pseudo_jvp(primals, tangents):
    """
    Custom JVP rule for ``pseudo_func``, built from the pseudo-continuum tables.

    The output tangent is the sum of three terms:
      * energy: the temperature-interpolated finite-difference slope across each bin
        (``integrate=False``), times ``jnp.diff(energy_dot)``;
      * temperature: slope of the linear temperature interpolation, times ``kT_dot``;
      * abundances: the flux recomputed with ``abundances_dot`` as weights (the flux is
        linear in the abundances).

    NOTE(review): the exact derivative of a bin integral with respect to its edges would be
    ``f(e_high) * de_high - f(e_low) * de_low``. The energy term used here is a different
    approximation — confirm it is intended.
    """
    energy, kT, abundances = primals
    energy_dot, kT_dot, abundances_dot = tangents

    idx, kT_low, kT_high = get_temperature(kT)
    low_table, high_table = get_pseudo(idx), get_pseudo(idx + 1)

    def at_brackets(weights, integrate=True):
        # Element-weighted flux for the two bracketing temperature tables
        low = interp_flux_elements(*low_table, energy, weights, integrate=integrate)
        high = interp_flux_elements(*high_table, energy, weights, integrate=integrate)
        return low, high

    slope_low, slope_high = at_brackets(abundances, integrate=False)
    energy_term = lerp(kT, kT_low, kT_high, slope_low, slope_high) * jnp.diff(energy_dot)

    flux_low, flux_high = at_brackets(abundances)
    kT_term = (flux_high - flux_low) / (kT_high - kT_low) * kT_dot

    weighted_low, weighted_high = at_brackets(abundances_dot)
    abundance_term = lerp(kT, kT_low, kT_high, weighted_low, weighted_high)

    return pseudo_func(*primals), energy_term + kT_term + abundance_term
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
@jax.custom_jvp
@jax.jit
def lines_func(energy, kT, abundances, broadening):
    """
    Broadened emission-line flux per energy bin, linearly interpolated in temperature.

    The line lists of the two tabulated temperatures bracketing ``kT`` are accumulated
    with ``get_lines_contribution_broadening``. The results are then interpolated with
    ``lerp`` at ``kT``. The derivative rule is ``lines_jvp``.
    """
    idx, kT_low, kT_high = get_temperature(kT)
    bracket_fluxes = [
        get_lines_contribution_broadening(*get_lines(i), energy, abundances, broadening) for i in (idx, idx + 1)
    ]
    return lerp(kT, kT_low, kT_high, *bracket_fluxes)
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
@jax.jit
@lines_func.defjvp
def lines_jvp(primals, tangents):
    """
    Custom JVP rule for ``lines_func``.

    The output tangent is the sum of:
      * temperature: slope of the linear temperature interpolation, times ``kT_dot``;
      * abundances: the line flux recomputed with ``abundances_dot`` as weights
        (``get_lines_contribution_broadening`` is linear in the abundances).

    This matches how ``continuum_jvp`` / ``pseudo_jvp`` treat kT and abundances. The
    tangents with respect to ``energy`` and ``broadening`` are still taken as zero. In
    particular, any kT dependence carried through the thermal broadening argument is not
    propagated.
    """
    energy, kT, abundances, broadening = primals
    # energy_dot and broadening_dot are intentionally unused (treated as zero)
    energy_dot, kT_dot, abundances_dot, broadening_dot = tangents

    idx, kT_low, kT_high = get_temperature(kT)
    lines_low_pars = get_lines(idx)
    lines_high_pars = get_lines(idx + 1)

    # Primal output and temperature derivative from the two bracketing line lists
    line_low = get_lines_contribution_broadening(*lines_low_pars, energy, abundances, broadening)
    line_high = get_lines_contribution_broadening(*lines_high_pars, energy, abundances, broadening)
    primals_out = lerp(kT, kT_low, kT_high, line_low, line_high)
    kT_derivative = (line_high - line_low) / (kT_high - kT_low) * kT_dot

    # Abundances derivative (exact, since the flux is linear in the abundances)
    dline_low_dabund = get_lines_contribution_broadening(*lines_low_pars, energy, abundances_dot, broadening)
    dline_high_dabund = get_lines_contribution_broadening(*lines_high_pars, energy, abundances_dot, broadening)
    abundances_derivative = lerp(kT, kT_low, kT_high, dline_low_dabund, dline_high_dabund)

    return primals_out, kT_derivative + abundances_derivative
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
class APEC(AdditiveComponent):
    """
    APEC model implementation in pure JAX for X-ray spectral fitting.

    !!! warning
        This implementation is optimised for the CPU, it shows poor performance on the GPU.

    Parameters:
        continuum: include the continuum contribution.
        pseudo: include the pseudo-continuum contribution.
        lines: include the emission-line contribution.
        thermal_broadening: broaden lines with the per-element thermal width (uses "kT").
        turbulent_broadening: broaden lines with a "Velocity" dispersion parameter (km/s).
        variant: "none" fits one "Abundance" parameter shared by C, N, O, Ne, Mg, Al, Si, S,
            Ar, Ca, Fe and Ni. "v" fits He, C, N, O, Ne, Mg, Al, Si, S, Ar, Ca, Fe and Ni
            individually. "vv" fits every element except H.
        abundance_table: name of a column of the abundance table. Any value other than
            "angr" rescales the abundances by that column's ratio to "angr".
        trace_abundance: multiplier applied to the trace elements that the chosen variant
            does not fit.
    """

    def __init__(
        self,
        continuum: bool = True,
        pseudo: bool = True,
        lines: bool = True,
        thermal_broadening: bool = True,
        turbulent_broadening: bool = True,
        variant: Literal["none", "v", "vv"] = "none",
        abundance_table: Literal["angr", "aspl", "feld", "aneb", "grsa", "wilm", "lodd", "lgpp", "lgps"] = "angr",
        trace_abundance: float = 1.0,
        **kwargs,
    ):
        super(APEC, self).__init__(**kwargs)

        # Atomic weight per element, used by the thermal broadening
        self.atomic_weights = jnp.asarray(element_data["atomic_weight"].to_numpy())

        self.abundance_table = abundance_table
        self.thermal_broadening = thermal_broadening
        self.turbulent_broadening = turbulent_broadening
        self.continuum_to_compute = continuum
        self.pseudo_to_compute = pseudo
        self.lines_to_compute = lines
        self.trace_abundance = trace_abundance
        self.variant = variant

    def get_thermal_broadening(self):
        r"""
        Compute the thermal broadening $\sigma_T$ for each element using :

        $$ \frac{\sigma_T}{E_{\text{line}}} = \frac{1}{c}\sqrt{\frac{k_{B} T}{A m_p}}$$

        where $E_{\text{line}}$ is the energy of the line, $c$ is the speed of light, $k_{B}$ is the Boltzmann constant,
        $T$ is the temperature, $A$ is the atomic weight of the element and $m_p$ is the proton mass.

        Returns one value per element (zeros when thermal broadening is disabled). The value
        must be multiplied by the line energy to get the actual broadening.
        """

        if self.thermal_broadening:
            kT = hk.get_parameter("kT", [], init=HaikuConstant(6.5))
            factor = 1 / c * (1 / m_p) ** (1 / 2)
            factor = factor.to(u.keV ** (-1 / 2)).value

            # Multiply this factor by Line_Energy * sqrt(kT/A) to get the broadening for a line
            return factor * jnp.sqrt(kT / self.atomic_weights)

        else:
            return jnp.zeros((30,))

    def get_turbulent_broadening(self):
        r"""
        Return the turbulent broadening using :

        $$\frac{\sigma_\text{turb}}{E_{\text{line}}} = \frac{\sigma_{v ~ ||}}{c}$$

        where $\sigma_{v ~ ||}$ is the velocity dispersion along the line of sight in km/s.
        The value must be multiplied by the line energy to get the actual broadening.
        """
        if self.turbulent_broadening:
            return hk.get_parameter("Velocity", [], init=HaikuConstant(100.0)) / c.to(u.km / u.s).value
        else:
            return 0.0

    def get_parameters(self):
        """
        Declare the fit parameters and build the 30-element abundance vector.

        Returns:
            (kT, z, norm, abund), where abund has one entry per element.
        """
        none_elements = ["C", "N", "O", "Ne", "Mg", "Al", "Si", "S", "Ar", "Ca", "Fe", "Ni"]
        v_elements = ["He", "C", "N", "O", "Ne", "Mg", "Al", "Si", "S", "Ar", "Ca", "Fe", "Ni"]
        trace_elements = jnp.asarray([3, 4, 5, 9, 11, 15, 17, 19, 21, 22, 23, 24, 25, 27, 29, 30], dtype=int) - 1

        # Set abundances of trace element (will be overwritten in the vv case)
        abund = jnp.ones((30,)).at[trace_elements].multiply(self.trace_abundance)

        if self.variant == "vv":
            for i, element in enumerate(abundance_table["Element"]):
                if element != "H":
                    abund = abund.at[i].set(hk.get_parameter(element, [], init=HaikuConstant(1.0)))

        elif self.variant == "v":
            for i, element in enumerate(abundance_table["Element"]):
                if element != "H" and element in v_elements:
                    abund = abund.at[i].set(hk.get_parameter(element, [], init=HaikuConstant(1.0)))

        else:
            Z = hk.get_parameter("Abundance", [], init=HaikuConstant(1.0))
            for i, element in enumerate(abundance_table["Element"]):
                if element != "H" and element in none_elements:
                    abund = abund.at[i].set(Z)

        # Rescale from the "angr" reference to the requested table. The test must be on the
        # table *name*; comparing the abundance array to a string is not a table check.
        if self.abundance_table != "angr":
            abund = abund * jnp.asarray(abundance_table[self.abundance_table] / abundance_table["angr"])

        # Set the temperature, redshift, normalisation
        kT = hk.get_parameter("kT", [], init=HaikuConstant(6.5))
        z = hk.get_parameter("Redshift", [], init=HaikuConstant(0.0))
        norm = hk.get_parameter("norm", [], init=HaikuConstant(1.0))

        return kT, z, norm, abund

    def emission_lines(self, e_low, e_high):
        """
        Compute the APEC flux in the bins [e_low, e_high].

        The bin edges are multiplied by (1 + z) before the table lookup. The enabled
        contributions (continuum, pseudo-continuum, lines) are summed and scaled by
        norm * 1e14 / (1 + z).

        Returns:
            (flux per bin, bin mid-points).
        """
        # Get the parameters and extract the relevant data
        energy = jnp.hstack([e_low, e_high[-1]])
        kT, z, norm, abundances = self.get_parameters()
        total_broadening = jnp.hypot(self.get_thermal_broadening(), self.get_turbulent_broadening())
        energy = energy * (1 + z)

        continuum = continuum_func(energy, kT, abundances) if self.continuum_to_compute else 0.0
        pseudo_continuum = pseudo_func(energy, kT, abundances) if self.pseudo_to_compute else 0.0
        lines = lines_func(energy, kT, abundances, total_broadening) if self.lines_to_compute else 0.0

        return (continuum + pseudo_continuum + lines) * norm * 1e14 / (1 + z), (e_low + e_high) / 2
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
""" This module contains the functions that load the APEC tables from the HDF5 file. They are implemented as JAX
|
|
2
|
+
pure callback to enable reading data from the files without saturating the memory. """
|
|
3
|
+
|
|
4
|
+
import jax
|
|
5
|
+
import jax.numpy as jnp
|
|
6
|
+
import numpy as np
|
|
7
|
+
import importlib.resources
|
|
8
|
+
import xarray as xr
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
# APEC tables shipped with the package, opened with xarray (h5netcdf engine) at import
# time and shared by all the table getters of this module.
apec_file = xr.open_dataset(
    importlib.resources.files("jaxspec") / "tables/apec.nc",
    engine="h5netcdf",
)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def temperature_table_getter(kT):
    """
    Locate the tabulated temperatures bracketing ``kT``.

    Returns ``(idx, kT_low, kT_high)`` with ``kT_low = temperature[idx]`` and
    ``kT_high = temperature[idx + 1]``.

    When ``kT`` lies outside the tabulated range (below the first or above the last
    temperature), ``(idx, 0.0, 0.0)`` is returned. The downstream linear interpolation
    then divides by zero, so the flux is non-finite rather than silently wrong.
    """
    temperatures = np.asarray(apec_file.temperature)
    idx = np.searchsorted(temperatures, kT) - 1

    # searchsorted(side="left") maps a kT equal to the first tabulated temperature to -1,
    # although it belongs to the first interval [T0, T1].
    if idx < 0 and kT == temperatures[0]:
        idx = idx + 1

    # A negative index would otherwise wrap around to the hottest table.
    if idx < 0 or idx > len(temperatures) - 2:
        return idx, 0.0, 0.0

    return idx, float(temperatures[idx]), float(temperatures[idx + 1])
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def continuum_table_getter(idx):
    """
    Return the continuum table of temperature index ``idx``.

    Returns ``(energy, emissivity, end_index)`` for that temperature. For an index outside
    ``0 .. n_temperatures - 1``, zero-filled arrays of the same shapes are returned.
    """
    # Valid indices go up to n - 1: the interpolation requests idx + 1, which is n - 1 for
    # the hottest bracket, so that last table must be returned rather than zeros.
    if idx < 0 or idx > len(apec_file.temperature) - 1:
        continuum_energy_array = jnp.zeros_like(apec_file.continuum_energy[0])
        continuum_emissivity_array = jnp.zeros_like(apec_file.continuum_emissivity[0])
        end_index_continuum = jnp.zeros_like(apec_file.continuum_end_index[0])

    else:
        continuum_energy_array = jnp.asarray(apec_file.continuum_energy[idx])
        continuum_emissivity_array = jnp.asarray(apec_file.continuum_emissivity[idx])
        end_index_continuum = jnp.asarray(apec_file.continuum_end_index[idx])

    return continuum_energy_array, continuum_emissivity_array, end_index_continuum
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def pseudo_table_getter(idx):
    """
    Return the pseudo-continuum table of temperature index ``idx``.

    Returns ``(energy, emissivity, end_index)`` for that temperature. For an index outside
    ``0 .. n_temperatures - 1``, zero-filled arrays of the same shapes are returned.
    """
    # Valid indices go up to n - 1: the interpolation requests idx + 1, which is n - 1 for
    # the hottest bracket, so that last table must be returned rather than zeros.
    if idx < 0 or idx > len(apec_file.temperature) - 1:
        pseudo_energy_array = jnp.zeros_like(apec_file.pseudo_energy[0])
        pseudo_emissivity_array = jnp.zeros_like(apec_file.pseudo_emissivity[0])
        end_index_pseudo = jnp.zeros_like(apec_file.pseudo_end_index[0])

    else:
        pseudo_energy_array = jnp.asarray(apec_file.pseudo_energy[idx])
        pseudo_emissivity_array = jnp.asarray(apec_file.pseudo_emissivity[idx])
        end_index_pseudo = jnp.asarray(apec_file.pseudo_end_index[idx])

    return pseudo_energy_array, pseudo_emissivity_array, end_index_pseudo
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def lines_table_getter(idx):
    """
    Return the emission-line list of temperature index ``idx``.

    Returns ``(line_energy, line_element, line_emissivity, end_index)`` for that
    temperature. For an index outside ``0 .. n_temperatures - 1``, zero-filled arrays of
    the same shapes are returned.
    """
    # Valid indices go up to n - 1: the interpolation requests idx + 1, which is n - 1 for
    # the hottest bracket, so that last table must be returned rather than zeros.
    if idx < 0 or idx > len(apec_file.temperature) - 1:
        line_energy_array = jnp.zeros_like(apec_file.line_energy[0])
        line_element_array = jnp.zeros_like(apec_file.line_element[0])
        line_emissivity_array = jnp.zeros_like(apec_file.line_emissivity[0])
        end_index_lines = jnp.zeros_like(apec_file.line_end_index[0])

    else:
        line_energy_array = jnp.asarray(apec_file.line_energy[idx])
        line_element_array = jnp.asarray(apec_file.line_element[idx])
        line_emissivity_array = jnp.asarray(apec_file.line_emissivity[idx])
        end_index_lines = jnp.asarray(apec_file.line_end_index[idx])

    return line_energy_array, line_element_array, line_emissivity_array, end_index_lines
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def _callback_result_shape(getter, sample_argument):
    """Abstract shape/dtype pytree of ``getter(sample_argument)`` with its leaves passed through ``jnp.asarray``."""
    return jax.eval_shape(lambda: jax.tree.map(jnp.asarray, getter(sample_argument)))


# Result specifications for the jax.pure_callback wrappers; each getter is called once,
# at import time, on a representative argument.
pure_callback_temperature_shape = _callback_result_shape(temperature_table_getter, 10.0)
pure_callback_continuum_shape = _callback_result_shape(continuum_table_getter, 0)
pure_callback_pseudo_shape = _callback_result_shape(pseudo_table_getter, 0)
pure_callback_line_shape = _callback_result_shape(lines_table_getter, 0)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
@jax.jit
def get_temperature(kT):
    """
    Traceable lookup of the temperature bracket of ``kT``.

    Runs ``temperature_table_getter`` on the host through ``jax.pure_callback`` and returns
    ``(idx, kT_low, kT_high)``.
    """
    callback, result_spec = temperature_table_getter, pure_callback_temperature_shape
    return jax.pure_callback(callback, result_spec, kT)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
@jax.jit
def get_continuum(idx):
    """
    Traceable access to the continuum table of temperature index ``idx``.

    Runs ``continuum_table_getter`` on the host through ``jax.pure_callback`` and returns
    ``(energy, emissivity, end_index)``.
    """
    callback, result_spec = continuum_table_getter, pure_callback_continuum_shape
    return jax.pure_callback(callback, result_spec, idx)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
@jax.jit
def get_pseudo(idx):
    """
    Traceable access to the pseudo-continuum table of temperature index ``idx``.

    Runs ``pseudo_table_getter`` on the host through ``jax.pure_callback`` and returns
    ``(energy, emissivity, end_index)``.
    """
    callback, result_spec = pseudo_table_getter, pure_callback_pseudo_shape
    return jax.pure_callback(callback, result_spec, idx)
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
@jax.jit
def get_lines(idx):
    """
    Traceable access to the emission-line list of temperature index ``idx``.

    Runs ``lines_table_getter`` on the host through ``jax.pure_callback`` and returns
    ``(line_energy, line_element, line_emissivity, end_index)``.
    """
    callback, result_spec = lines_table_getter, pure_callback_line_shape
    return jax.pure_callback(callback, result_spec, idx)
|
jaxspec/model/abc.py
CHANGED
|
@@ -123,9 +123,9 @@ class SpectralModel:
|
|
|
123
123
|
|
|
124
124
|
!!! info
|
|
125
125
|
This method is internally used in the inference process and should not be used directly. See
|
|
126
|
-
[`photon_flux`][jaxspec.analysis.results.
|
|
126
|
+
[`photon_flux`][jaxspec.analysis.results.FitResult.photon_flux] to compute
|
|
127
127
|
the photon flux associated with a set of fitted parameters in a
|
|
128
|
-
[`
|
|
128
|
+
[`FitResult`][jaxspec.analysis.results.FitResult]
|
|
129
129
|
instead.
|
|
130
130
|
"""
|
|
131
131
|
|
|
@@ -151,9 +151,9 @@ class SpectralModel:
|
|
|
151
151
|
|
|
152
152
|
!!! info
|
|
153
153
|
This method is internally used in the inference process and should not be used directly. See
|
|
154
|
-
[`energy_flux`](/references/results/#jaxspec.analysis.results.
|
|
154
|
+
[`energy_flux`](/references/results/#jaxspec.analysis.results.FitResult.energy_flux) to compute
|
|
155
155
|
the energy flux associated with a set of fitted parameters in a
|
|
156
|
-
[`
|
|
156
|
+
[`FitResult`](/references/results/#jaxspec.analysis.results.FitResult)
|
|
157
157
|
instead.
|
|
158
158
|
"""
|
|
159
159
|
|
|
@@ -309,7 +309,7 @@ class SpectralModel:
|
|
|
309
309
|
if component.type == "additive":
|
|
310
310
|
|
|
311
311
|
def lam_func(e):
|
|
312
|
-
return component().continuum(e) + component().emission_lines(e, e + 1)[0]
|
|
312
|
+
return component(**kwargs).continuum(e) + component(**kwargs).emission_lines(e, e + 1)[0]
|
|
313
313
|
|
|
314
314
|
elif component.type == "multiplicative":
|
|
315
315
|
|
|
@@ -326,7 +326,7 @@ class SpectralModel:
|
|
|
326
326
|
"component_type": component.type,
|
|
327
327
|
"name": component.__name__.lower(),
|
|
328
328
|
"component": component,
|
|
329
|
-
"params": hk.transform(lam_func).init(None, jnp.ones(1)),
|
|
329
|
+
# "params": hk.transform(lam_func).init(None, jnp.ones(1)),
|
|
330
330
|
"fine_structure": False,
|
|
331
331
|
"kwargs": kwargs,
|
|
332
332
|
"depth": 0,
|
|
@@ -454,7 +454,7 @@ class ComponentMetaClass(type(hk.Module)):
|
|
|
454
454
|
syntax while style enabling the components to be used as haiku modules.
|
|
455
455
|
"""
|
|
456
456
|
|
|
457
|
-
def __call__(self, **kwargs):
|
|
457
|
+
def __call__(self, **kwargs) -> SpectralModel:
|
|
458
458
|
"""
|
|
459
459
|
This method enable to use model components as haiku modules when folded in a haiku transform
|
|
460
460
|
function and also to instantiate them as SpectralModel when out of a haiku transform
|
|
@@ -476,3 +476,51 @@ class ModelComponent(hk.Module, ABC, metaclass=ComponentMetaClass):
|
|
|
476
476
|
|
|
477
477
|
def __init__(self, *args, **kwargs):
|
|
478
478
|
super().__init__(*args, **kwargs)
|
|
479
|
+
|
|
480
|
+
|
|
481
|
+
class AdditiveComponent(ModelComponent, ABC):
    """
    Base class for additive model components.

    Subclasses override ``continuum`` and/or ``emission_lines``; both default to a zero
    contribution.
    """

    type = "additive"

    def continuum(self, energy):
        """
        Method for computing the continuum associated to the model.
        By default, this is set to 0, which means that the model has no continuum.
        This should be overloaded by the user if the model has a continuum.
        """

        return jnp.zeros_like(energy)

    def emission_lines(self, e_min, e_max) -> (jax.Array, jax.Array):
        """
        Method for computing the fine structure of an additive model between two energies.
        By default, this is set to 0, which means that the model has no emission lines.
        This should be overloaded by the user if the model has a fine structure.

        Returns:
            (flux per bin, bin mid-points ``(e_min + e_max) / 2``).
        """

        return jnp.zeros_like(e_min), (e_min + e_max) / 2
|
jaxspec/model/additive.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from abc import ABC
|
|
4
3
|
|
|
5
4
|
import haiku as hk
|
|
6
5
|
import jax
|
|
@@ -8,58 +7,10 @@ import jax.numpy as jnp
|
|
|
8
7
|
import jax.scipy as jsp
|
|
9
8
|
import astropy.units as u
|
|
10
9
|
import astropy.constants
|
|
11
|
-
|
|
12
|
-
from .abc import ModelComponent
|
|
13
10
|
from haiku.initializers import Constant as HaikuConstant
|
|
14
11
|
from ..util.integrate import integrate_interval
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
class AdditiveComponent(ModelComponent, ABC):
|
|
18
|
-
type = "additive"
|
|
19
|
-
|
|
20
|
-
def continuum(self, energy):
|
|
21
|
-
"""
|
|
22
|
-
Method for computing the continuum associated to the model.
|
|
23
|
-
By default, this is set to 0, which means that the model has no continuum.
|
|
24
|
-
This should be overloaded by the user if the model has a continuum.
|
|
25
|
-
"""
|
|
26
|
-
|
|
27
|
-
return jnp.zeros_like(energy)
|
|
28
|
-
|
|
29
|
-
def emission_lines(self, e_min, e_max) -> (jax.Array, jax.Array):
|
|
30
|
-
"""
|
|
31
|
-
Method for computing the fine structure of an additive model between two energies.
|
|
32
|
-
By default, this is set to 0, which means that the model has no emission lines.
|
|
33
|
-
This should be overloaded by the user if the model has a fine structure.
|
|
34
|
-
"""
|
|
35
|
-
|
|
36
|
-
return jnp.zeros_like(e_min), (e_min + e_max) / 2
|
|
37
|
-
|
|
38
|
-
'''
|
|
39
|
-
def integral(self, e_min, e_max):
|
|
40
|
-
r"""
|
|
41
|
-
Method for integrating an additive model between two energies. It relies on
|
|
42
|
-
double exponential quadrature for finite intervals to compute an approximation
|
|
43
|
-
of the integral of a model.
|
|
44
|
-
|
|
45
|
-
references
|
|
46
|
-
----------
|
|
47
|
-
* $Takahasi and Mori (1974) <https://ems.press/journals/prims/articles/2686>$_
|
|
48
|
-
* $Mori and Sugihara (2001) <https://doi.org/10.1016/S0377-0427(00)00501-X>$_
|
|
49
|
-
* $Tanh-sinh quadrature <https://en.wikipedia.org/wiki/Tanh-sinh_quadrature>$_ from Wikipedia
|
|
50
|
-
|
|
51
|
-
"""
|
|
52
|
-
|
|
53
|
-
t = jnp.linspace(-4, 4, 71) # The number of points used is hardcoded and this is not ideal
|
|
54
|
-
# Quadrature nodes as defined in reference
|
|
55
|
-
phi = jnp.tanh(jnp.pi / 2 * jnp.sinh(t))
|
|
56
|
-
dphi = jnp.pi / 2 * jnp.cosh(t) * (1 / jnp.cosh(jnp.pi / 2 * jnp.sinh(t)) ** 2)
|
|
57
|
-
# Change of variable to turn the integral from E_min to E_max into an integral from -1 to 1
|
|
58
|
-
x = (e_max - e_min) / 2 * phi + (e_max + e_min) / 2
|
|
59
|
-
dx = (e_max - e_min) / 2 * dphi
|
|
60
|
-
|
|
61
|
-
return jnp.trapz(self(x) * dx, x=t)
|
|
62
|
-
'''
|
|
12
|
+
from .abc import AdditiveComponent
|
|
13
|
+
# from ._additive.apec import APEC # noqa: F401
|
|
63
14
|
|
|
64
15
|
|
|
65
16
|
class Powerlaw(AdditiveComponent):
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
Element angr aspl feld aneb grsa wilm lodd lgpp lgps
|
|
2
|
+
H 1.00e+00 1.00e+00 1.00e+00 1.00e+00 1.00e+00 1.00e+00 1.00e+00 1.00E+00 1.00E+00
|
|
3
|
+
He 9.77e-02 8.51e-02 9.77e-02 8.01e-02 8.51e-02 9.77e-02 7.92e-02 8.41E-02 9.69E-02
|
|
4
|
+
Li 1.45e-11 1.12e-11 1.26e-11 2.19e-09 1.26e-11 0.00 1.90e-09 1.26E-11 2.15E-09
|
|
5
|
+
Be 1.41e-11 2.40e-11 2.51e-11 2.87e-11 2.51e-11 0.00 2.57e-11 2.40E-11 2.36E-11
|
|
6
|
+
B 3.98e-10 5.01e-10 3.55e-10 8.82e-10 3.55e-10 0.00 6.03e-10 5.01E-10 7.26E-10
|
|
7
|
+
C 3.63e-04 2.69e-04 3.98e-04 4.45e-04 3.31e-04 2.40e-04 2.45e-04 2.45E-04 2.78E-04
|
|
8
|
+
N 1.12e-04 6.76e-05 1.00e-04 9.12e-05 8.32e-05 7.59e-05 6.76e-05 7.24E-05 8.19E-05
|
|
9
|
+
O 8.51e-04 4.90e-04 8.51e-04 7.39e-04 6.76e-04 4.90e-04 4.90e-04 5.37E-04 6.06E-04
|
|
10
|
+
F 3.63e-08 3.63e-08 3.63e-08 3.10e-08 3.63e-08 0.00 2.88e-08 3.63E-08 3.10E-08
|
|
11
|
+
Ne 1.23e-04 8.51e-05 1.29e-04 1.38e-04 1.20e-04 8.71e-05 7.41e-05 1.12E-04 1.27E-04
|
|
12
|
+
Na 2.14e-06 1.74e-06 2.14e-06 2.10e-06 2.14e-06 1.45e-06 1.99e-06 2.00E-06 2.23E-06
|
|
13
|
+
Mg 3.80e-05 3.98e-05 3.80e-05 3.95e-05 3.80e-05 2.51e-05 3.55e-05 3.47E-05 3.98E-05
|
|
14
|
+
Al 2.95e-06 2.82e-06 2.95e-06 3.12e-06 2.95e-06 2.14e-06 2.88e-06 2.95E-06 3.27E-06
|
|
15
|
+
Si 3.55e-05 3.24e-05 3.55e-05 3.68e-05 3.55e-05 1.86e-05 3.47e-05 3.31E-05 3.86E-05
|
|
16
|
+
P 2.82e-07 2.57e-07 2.82e-07 3.82e-07 2.82e-07 2.63e-07 2.88e-07 2.88E-07 3.20E-07
|
|
17
|
+
S 1.62e-05 1.32e-05 1.62e-05 1.89e-05 2.14e-05 1.23e-05 1.55e-05 1.38E-05 1.63E-05
|
|
18
|
+
Cl 3.16e-07 3.16e-07 3.16e-07 1.93e-07 3.16e-07 1.32e-07 1.82e-07 3.16E-07 2.00E-07
|
|
19
|
+
Ar 3.63e-06 2.51e-06 4.47e-06 3.82e-06 2.51e-06 2.57e-06 3.55e-06 3.16E-06 3.58E-06
|
|
20
|
+
K 1.32e-07 1.07e-07 1.32e-07 1.39e-07 1.32e-07 0.00 1.29e-07 1.32E-07 1.45E-07
|
|
21
|
+
Ca 2.29e-06 2.19e-06 2.29e-06 2.25e-06 2.29e-06 1.58e-06 2.19e-06 2.14E-06 2.33E-06
|
|
22
|
+
Sc 1.26e-09 1.41e-09 1.48e-09 1.24e-09 1.48e-09 0.00 1.17e-09 1.26E-09 1.33E-09
|
|
23
|
+
Ti 9.77e-08 8.91e-08 1.05e-07 8.82e-08 1.05e-07 6.46e-08 8.32e-08 7.94E-08 9.54E-08
|
|
24
|
+
V 1.00e-08 8.51e-09 1.00e-08 1.08e-08 1.00e-08 0.00 1.00e-08 1.00E-08 1.11E-08
|
|
25
|
+
Cr 4.68e-07 4.37e-07 4.68e-07 4.93e-07 4.68e-07 3.24e-07 4.47e-07 4.37E-07 5.06E-07
|
|
26
|
+
Mn 2.45e-07 2.69e-07 2.45e-07 3.50e-07 2.45e-07 2.19e-07 3.16e-07 2.34E-07 3.56E-07
|
|
27
|
+
Fe 4.68e-05 3.16e-05 3.24e-05 3.31e-05 3.16e-05 2.69e-05 2.95e-05 2.82E-05 3.27E-05
|
|
28
|
+
Co 8.32e-08 9.77e-08 8.32e-08 8.27e-08 8.32e-08 8.32e-08 8.13e-08 8.32E-08 9.07E-08
|
|
29
|
+
Ni 1.78e-06 1.66e-06 1.78e-06 1.81e-06 1.78e-06 1.12e-06 1.66e-06 1.70E-06 1.89E-06
|
|
30
|
+
Cu 1.62e-08 1.55e-08 1.62e-08 1.89e-08 1.62e-08 0.00 1.82e-08 1.62E-08 2.09E-08
|
|
31
|
+
Zn 3.98e-08 3.63e-08 3.98e-08 4.63e-08 3.98e-08 0.00 4.27e-08 4.17E-08 5.02E-08
|