jaxspec 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,377 @@
1
+ import jax.numpy as jnp
2
+ import jax
3
+ import haiku as hk
4
+ import astropy.units as u
5
+ from jax import lax
6
+ from jax.lax import scan, fori_loop
7
+ from jax.scipy.stats import norm as gaussian
8
+ from typing import Literal
9
+ from ...util.abundance import abundance_table, element_data
10
+ from haiku.initializers import Constant as HaikuConstant
11
+ from astropy.constants import c, m_p
12
+ from ..abc import AdditiveComponent
13
+ from .apec_loaders import get_temperature, get_continuum, get_pseudo, get_lines
14
+
15
+
16
+ @jax.jit
17
+ def lerp(x, x0, x1, y0, y1):
18
+ """
19
+ Linear interpolation routine
20
+ Return y(x) = (y0 * (x1 - x) + y1 * (x - x0)) / (x1 - x0)
21
+ """
22
+ return (y0 * (x1 - x) + y1 * (x - x0)) / (x1 - x0)
23
+
24
+
25
+ @jax.jit
26
+ def interp_and_integrate(energy_low, energy_high, energy_ref, continuum_ref, end_index):
27
+ """
28
+ This function interpolates & integrates the values of a tabulated reference continuum between two energy limits
29
+ Sorry for the boilerplate here, but be sure that it works !
30
+
31
+ Parameters:
32
+ energy_low: lower limit of the integral
33
+ energy_high: upper limit of the integral
34
+ energy_ref: energy grid of the reference continuum
35
+ continuum_ref: continuum values evaluated at energy_ref
36
+
37
+ """
38
+ energy_ref = jnp.where(jnp.arange(energy_ref.shape[0]) < end_index, energy_ref, jnp.nan)
39
+ start_index = jnp.searchsorted(energy_ref, energy_low, side="left") - 1
40
+ end_index = jnp.searchsorted(energy_ref, energy_high, side="left") + 1
41
+
42
+ def body_func(index, value):
43
+ integrated_flux, previous_energy, previous_continuum = value
44
+ current_energy, current_continuum = energy_ref[index], continuum_ref[index]
45
+
46
+ # 5 cases
47
+ # Neither the current nor the previous energy is within the integral limits > nothing is added to the integrated flux
48
+ # The left limit of the integral is between the current and previous energy > previous energy is set to the limit, previous continuum is interpolated, and then added to the integrated flux
49
+ # The right limit of the integral is between the current and previous energy > current energy is set to the limit, current continuum is interpolated, and then added to the integrated flux
50
+ # Both current and previous energies are within the integral limits -> add to the integrated flux
51
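+ # Both integral limits lie between the previous and the current energy > both ends are interpolated, and the enclosed area is added to the integrated flux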
52
+
53
+ current_energy_is_between = (energy_low <= current_energy) * (current_energy < energy_high)
54
+ previous_energy_is_between = (energy_low <= previous_energy) * (previous_energy < energy_high)
55
+ energies_within_bins = (previous_energy <= energy_low) * (energy_high < current_energy)
56
+
57
+ case = (
58
+ (1 - previous_energy_is_between) * current_energy_is_between * 1
59
+ + previous_energy_is_between * (1 - current_energy_is_between) * 2
60
+ + (previous_energy_is_between * current_energy_is_between) * 3
61
+ + energies_within_bins * 4
62
+ )
63
+
64
+ term_to_add = lax.switch(
65
+ case,
66
+ [
67
+ lambda pe, pc, ce, cc, el, er: 0.0, # 1
68
+ lambda pe, pc, ce, cc, el, er: (cc + lerp(el, pe, ce, pc, cc)) * (ce - el) / 2, # 2
69
+ lambda pe, pc, ce, cc, el, er: (pc + lerp(er, pe, ce, pc, cc)) * (er - pe) / 2, # 3
70
+ lambda pe, pc, ce, cc, el, er: (pc + cc) * (ce - pe) / 2, # 4
71
+ lambda pe, pc, ce, cc, el, er: (lerp(el, pe, ce, pc, cc) + lerp(er, pe, ce, pc, cc)) * (er - el) / 2,
72
+ # 5
73
+ ],
74
+ previous_energy,
75
+ previous_continuum,
76
+ current_energy,
77
+ current_continuum,
78
+ energy_low,
79
+ energy_high,
80
+ )
81
+
82
+ return integrated_flux + term_to_add, current_energy, current_continuum
83
+
84
+ integrated_flux, _, _ = fori_loop(start_index, end_index, body_func, (0.0, 0.0, 0.0))
85
+
86
+ return integrated_flux
87
+
88
+
89
+ def interp(e_low, e_high, energy_ref, continuum_ref, end_index):
90
+ energy_ref = jnp.where(jnp.arange(energy_ref.shape[0]) < end_index, energy_ref, jnp.nan)
91
+
92
+ return (jnp.interp(e_high, energy_ref, continuum_ref) - jnp.interp(e_low, energy_ref, continuum_ref)) / (e_high - e_low)
93
+
94
+
95
+ def interp_flux(energy, energy_ref, continuum_ref, end_index, integrate=True):
96
+ """
97
+ Iterate through an array of shape (energy_ref,) and compute the flux between the bins defined by energy
98
+ """
99
+
100
+ def scanned_func(carry, unpack):
101
+ e_low, e_high = unpack
102
+ if integrate:
103
+ continuum = interp_and_integrate(e_low, e_high, energy_ref, continuum_ref, end_index)
104
+ else:
105
+ continuum = interp(e_low, e_high, energy_ref, continuum_ref, end_index)
106
+
107
+ return carry, continuum
108
+
109
+ _, continuum = scan(scanned_func, 0.0, (energy[:-1], energy[1:]))
110
+
111
+ return continuum
112
+
113
+
114
+ def interp_flux_elements(energy_ref, continuum_ref, end_index, energy, abundances, integrate=True):
115
+ """
116
+ Iterate through an array of shape (abundance, energy_ref) and compute the flux between the bins defined by energy
117
+ and weight the flux depending on the abundance of each element
118
+ """
119
+
120
+ def scanned_func(_, unpack):
121
+ energy_ref, continuum_ref, end_idx = unpack
122
+ element_flux = interp_flux(energy, energy_ref, continuum_ref, end_idx, integrate=integrate)
123
+
124
+ return _, element_flux
125
+
126
+ _, flux = scan(scanned_func, 0.0, (energy_ref, continuum_ref, end_index))
127
+
128
+ return abundances @ flux
129
+
130
+
131
+ @jax.jit
132
+ def get_lines_contribution_broadening(
133
+ line_energy, line_element, line_emissivity, end_index, energy, abundances, total_broadening
134
+ ):
135
+ def body_func(i, flux):
136
+ # Notice the -1 in line element to match the 0-based indexing
137
+ l_energy, l_emissivity, l_element = line_energy[i], line_emissivity[i], line_element[i] - 1
138
+ broadening = l_energy * total_broadening[l_element]
139
+ l_flux = gaussian.cdf(energy[1:], l_energy, broadening) - gaussian.cdf(energy[:-1], l_energy, broadening)
140
+ l_flux = l_flux * l_emissivity * abundances[l_element]
141
+
142
+ return flux + l_flux
143
+
144
+ return fori_loop(0, end_index, body_func, jnp.zeros_like(energy[:-1]))
145
+
146
+
147
+ @jax.jit
148
+ def get_lines_contribution_broadening_derivative(
149
+ line_energy, line_element, line_emissivity, end_index, energy, abundances, total_broadening
150
+ ):
151
+ def body_func(i, flux):
152
+ # Notice the -1 in line element to match the 0-based indexing
153
+ l_energy, l_emissivity, l_element = line_energy[i], line_emissivity[i], line_element[i] - 1
154
+ broadening = l_energy * total_broadening[l_element]
155
+ l_flux = gaussian.cdf(energy[1:], l_energy, broadening) - gaussian.cdf(energy[:-1], l_energy, broadening)
156
+ l_flux = l_flux * l_emissivity * abundances[l_element]
157
+
158
+ return flux + l_flux
159
+
160
+ return fori_loop(0, end_index, body_func, jnp.zeros_like(energy[:-1]))
161
+
162
+
163
+ @jax.custom_jvp
164
+ @jax.jit
165
+ def continuum_func(energy, kT, abundances):
166
+ idx, kT_low, kT_high = get_temperature(kT)
167
+ continuum_low = interp_flux_elements(*get_continuum(idx), energy, abundances)
168
+ continuum_high = interp_flux_elements(*get_continuum(idx + 1), energy, abundances)
169
+
170
+ return lerp(kT, kT_low, kT_high, continuum_low, continuum_high)
171
+
172
+
173
+ @jax.jit
174
+ @continuum_func.defjvp
175
+ def continuum_jvp(primals, tangents):
176
+ energy, kT, abundances = primals
177
+ energy_dot, kT_dot, abundances_dot = tangents
178
+
179
+ idx, kT_low, kT_high = get_temperature(kT)
180
+ continuum_low_pars = get_continuum(idx)
181
+ continuum_high_pars = get_continuum(idx + 1)
182
+
183
+ # Energy derivative
184
+ dcontinuum_low_denerg = interp_flux_elements(*continuum_low_pars, energy, abundances, integrate=False)
185
+ dcontinuum_high_denerg = interp_flux_elements(*continuum_high_pars, energy, abundances, integrate=False)
186
+ energy_derivative = lerp(kT, kT_low, kT_high, dcontinuum_low_denerg, dcontinuum_high_denerg) * jnp.diff(energy_dot)
187
+
188
+ # Temperature derivative
189
+ continuum_low = interp_flux_elements(*continuum_low_pars, energy, abundances)
190
+ continuum_high = interp_flux_elements(*continuum_high_pars, energy, abundances)
191
+ kT_derivative = (continuum_high - continuum_low) / (kT_high - kT_low) * kT_dot
192
+
193
+ # Abundances derivative
194
+ dcontinuum_low_dabund = interp_flux_elements(*continuum_low_pars, energy, abundances_dot)
195
+ dcontinuum_high_dabund = interp_flux_elements(*continuum_high_pars, energy, abundances_dot)
196
+ abundances_derivative = lerp(kT, kT_low, kT_high, dcontinuum_low_dabund, dcontinuum_high_dabund)
197
+
198
+ primals_out = continuum_func(*primals)
199
+
200
+ return primals_out, energy_derivative + kT_derivative + abundances_derivative
201
+
202
+
203
+ @jax.custom_jvp
204
+ @jax.jit
205
+ def pseudo_func(energy, kT, abundances):
206
+ idx, kT_low, kT_high = get_temperature(kT)
207
+ continuum_low = interp_flux_elements(*get_continuum(idx), energy, abundances)
208
+ continuum_high = interp_flux_elements(*get_continuum(idx + 1), energy, abundances)
209
+
210
+ return lerp(kT, kT_low, kT_high, continuum_low, continuum_high)
211
+
212
+
213
+ @jax.jit
214
+ @pseudo_func.defjvp
215
+ def pseudo_jvp(primals, tangents):
216
+ energy, kT, abundances = primals
217
+ energy_dot, kT_dot, abundances_dot = tangents
218
+
219
+ idx, kT_low, kT_high = get_temperature(kT)
220
+ continuum_low_pars = get_pseudo(idx)
221
+ continuum_high_pars = get_pseudo(idx + 1)
222
+
223
+ # Energy derivative
224
+ dcontinuum_low_denerg = interp_flux_elements(*continuum_low_pars, energy, abundances, integrate=False)
225
+ dcontinuum_high_denerg = interp_flux_elements(*continuum_high_pars, energy, abundances, integrate=False)
226
+ energy_derivative = lerp(kT, kT_low, kT_high, dcontinuum_low_denerg, dcontinuum_high_denerg) * jnp.diff(energy_dot)
227
+
228
+ # Temperature derivative
229
+ continuum_low = interp_flux_elements(*continuum_low_pars, energy, abundances)
230
+ continuum_high = interp_flux_elements(*continuum_high_pars, energy, abundances)
231
+ kT_derivative = (continuum_high - continuum_low) / (kT_high - kT_low) * kT_dot
232
+
233
+ # Abundances derivative
234
+ dcontinuum_low_dabund = interp_flux_elements(*continuum_low_pars, energy, abundances_dot)
235
+ dcontinuum_high_dabund = interp_flux_elements(*continuum_high_pars, energy, abundances_dot)
236
+ abundances_derivative = lerp(kT, kT_low, kT_high, dcontinuum_low_dabund, dcontinuum_high_dabund)
237
+
238
+ primals_out = pseudo_func(*primals)
239
+
240
+ return primals_out, energy_derivative + kT_derivative + abundances_derivative
241
+
242
+
243
+ @jax.custom_jvp
244
+ @jax.jit
245
+ def lines_func(energy, kT, abundances, broadening):
246
+ idx, kT_low, kT_high = get_temperature(kT)
247
+ line_low = get_lines_contribution_broadening(*get_lines(idx), energy, abundances, broadening)
248
+ line_high = get_lines_contribution_broadening(*get_lines(idx + 1), energy, abundances, broadening)
249
+
250
+ return lerp(kT, kT_low, kT_high, line_low, line_high)
251
+
252
+
253
+ @jax.jit
254
+ @lines_func.defjvp
255
+ def lines_jvp(primals, tangents):
256
+ energy, kT, abundances, broadening = primals
257
+ energy_dot, kT_dot, abundances_dot, broadening_dot = tangents
258
+
259
+ primals_out = lines_func(*primals)
260
+ return primals_out, jnp.zeros_like(primals_out)
261
+
262
+
263
+ class APEC(AdditiveComponent):
264
+ """
265
+ APEC model implementation in pure JAX for X-ray spectral fitting.
266
+
267
+ !!! warning
268
+ This implementation is optimised for the CPU, it shows poor performance on the GPU.
269
+ """
270
+
271
+ def __init__(
272
+ self,
273
+ continuum: bool = True,
274
+ pseudo: bool = True,
275
+ lines: bool = True,
276
+ thermal_broadening: bool = True,
277
+ turbulent_broadening: bool = True,
278
+ variant: Literal["none", "v", "vv"] = "none",
279
+ abundance_table: Literal["angr", "aspl", "feld", "aneb", "grsa", "wilm", "lodd", "lgpp", "lgps"] = "angr",
280
+ trace_abundance: float = 1.0,
281
+ **kwargs,
282
+ ):
283
+ super(APEC, self).__init__(**kwargs)
284
+
285
+ self.atomic_weights = jnp.asarray(element_data["atomic_weight"].to_numpy())
286
+
287
+ self.abundance_table = abundance_table
288
+ self.thermal_broadening = thermal_broadening
289
+ self.turbulent_broadening = turbulent_broadening
290
+ self.continuum_to_compute = continuum
291
+ self.pseudo_to_compute = pseudo
292
+ self.lines_to_compute = lines
293
+ self.trace_abundance = trace_abundance
294
+ self.variant = variant
295
+
296
+ def get_thermal_broadening(self):
297
+ r"""
298
+ Compute the thermal broadening $\sigma_T$ for each element using :
299
+
300
+ $$ \frac{\sigma_T}{E_{\text{line}}} = \frac{1}{c}\sqrt{\frac{k_{B} T}{A m_p}}$$
301
+
302
+ where $E_{\text{line}}$ is the energy of the line, $c$ is the speed of light, $k_{B}$ is the Boltzmann constant,
303
+ $T$ is the temperature, $A$ is the atomic weight of the element and $m_p$ is the proton mass.
304
+ """
305
+
306
+ if self.thermal_broadening:
307
+ kT = hk.get_parameter("kT", [], init=HaikuConstant(6.5))
308
+ factor = 1 / c * (1 / m_p) ** (1 / 2)
309
+ factor = factor.to(u.keV ** (-1 / 2)).value
310
+
311
+ # Multiply this factor by Line_Energy * sqrt(kT/A) to get the broadening for a line
312
+ # This return value must be multiplied by the energy of the line to get actual broadening
313
+ return factor * jnp.sqrt(kT / self.atomic_weights)
314
+
315
+ else:
316
+ return jnp.zeros((30,))
317
+
318
+ def get_turbulent_broadening(self):
319
+ r"""
320
+ Return the turbulent broadening using :
321
+
322
+ $$\frac{\sigma_\text{turb}}{E_{\text{line}}} = \frac{\sigma_{v ~ ||}}{c}$$
323
+
324
+ where $\sigma_{v ~ ||}$ is the velocity dispersion along the line of sight in km/s.
325
+ """
326
+ if self.turbulent_broadening:
327
+ # This return value must be multiplied by the energy of the line to get actual broadening
328
+ return hk.get_parameter("Velocity", [], init=HaikuConstant(100.0)) / c.to(u.km / u.s).value
329
+ else:
330
+ return 0.0
331
+
332
+ def get_parameters(self):
333
+ none_elements = ["C", "N", "O", "Ne", "Mg", "Al", "Si", "S", "Ar", "Ca", "Fe", "Ni"]
334
+ v_elements = ["He", "C", "N", "O", "Ne", "Mg", "Al", "Si", "S", "Ar", "Ca", "Fe", "Ni"]
335
+ trace_elements = jnp.asarray([3, 4, 5, 9, 11, 15, 17, 19, 21, 22, 23, 24, 25, 27, 29, 30], dtype=int) - 1
336
+
337
+ # Set abundances of trace element (will be overwritten in the vv case)
338
+ abund = jnp.ones((30,)).at[trace_elements].multiply(self.trace_abundance)
339
+
340
+ if self.variant == "vv":
341
+ for i, element in enumerate(abundance_table["Element"]):
342
+ if element != "H":
343
+ abund = abund.at[i].set(hk.get_parameter(element, [], init=HaikuConstant(1.0)))
344
+
345
+ elif self.variant == "v":
346
+ for i, element in enumerate(abundance_table["Element"]):
347
+ if element != "H" and element in v_elements:
348
+ abund = abund.at[i].set(hk.get_parameter(element, [], init=HaikuConstant(1.0)))
349
+
350
+ else:
351
+ Z = hk.get_parameter("Abundance", [], init=HaikuConstant(1.0))
352
+ for i, element in enumerate(abundance_table["Element"]):
353
+ if element != "H" and element in none_elements:
354
+ abund = abund.at[i].set(Z)
355
+
356
+ if abund != "angr":
357
+ abund = abund * jnp.asarray(abundance_table[self.abundance_table] / abundance_table["angr"])
358
+
359
+ # Set the temperature, redshift, normalisation
360
+ kT = hk.get_parameter("kT", [], init=HaikuConstant(6.5))
361
+ z = hk.get_parameter("Redshift", [], init=HaikuConstant(0.0))
362
+ norm = hk.get_parameter("norm", [], init=HaikuConstant(1.0))
363
+
364
+ return kT, z, norm, abund
365
+
366
+ def emission_lines(self, e_low, e_high):
367
+ # Get the parameters and extract the relevant data
368
+ energy = jnp.hstack([e_low, e_high[-1]])
369
+ kT, z, norm, abundances = self.get_parameters()
370
+ total_broadening = jnp.hypot(self.get_thermal_broadening(), self.get_turbulent_broadening())
371
+ energy = energy * (1 + z)
372
+
373
+ continuum = continuum_func(energy, kT, abundances) if self.continuum_to_compute else 0.0
374
+ pseudo_continuum = pseudo_func(energy, kT, abundances) if self.pseudo_to_compute else 0.0
375
+ lines = lines_func(energy, kT, abundances, total_broadening) if self.lines_to_compute else 0.0
376
+
377
+ return (continuum + pseudo_continuum + lines) * norm * 1e14 / (1 + z), (e_low + e_high) / 2
@@ -0,0 +1,90 @@
1
+ """ This module contains the functions that load the APEC tables from the HDF5 file. They are implemented as JAX
2
+ pure callbacks to enable reading data from the files without saturating the memory. """
3
+
4
+ import jax
5
+ import jax.numpy as jnp
6
+ import numpy as np
7
+ import importlib.resources
8
+ import xarray as xr
9
+
10
+
11
+ apec_file = xr.open_dataset(importlib.resources.files("jaxspec") / "tables/apec.nc", engine="h5netcdf")
12
+
13
+
14
+ def temperature_table_getter(kT):
15
+ idx = np.searchsorted(apec_file.temperature, kT) - 1
16
+
17
+ if idx > len(apec_file.temperature) - 2:
18
+ return idx, 0.0, 0.0
19
+ else:
20
+ return idx, float(apec_file.temperature[idx]), float(apec_file.temperature[idx + 1])
21
+
22
+
23
+ def continuum_table_getter(idx):
24
+ if idx > len(apec_file.temperature) - 2:
25
+ continuum_energy_array = jnp.zeros_like(apec_file.continuum_energy[0])
26
+ continuum_emissivity_array = jnp.zeros_like(apec_file.continuum_emissivity[0])
27
+ end_index_continuum = jnp.zeros_like(apec_file.continuum_end_index[0])
28
+
29
+ else:
30
+ continuum_energy_array = jnp.asarray(apec_file.continuum_energy[idx])
31
+ continuum_emissivity_array = jnp.asarray(apec_file.continuum_emissivity[idx])
32
+ end_index_continuum = jnp.asarray(apec_file.continuum_end_index[idx])
33
+
34
+ return continuum_energy_array, continuum_emissivity_array, end_index_continuum
35
+
36
+
37
+ def pseudo_table_getter(idx):
38
+ if idx > len(apec_file.temperature) - 2:
39
+ pseudo_energy_array = jnp.zeros_like(apec_file.pseudo_energy[0])
40
+ pseudo_emissivity_array = jnp.zeros_like(apec_file.pseudo_emissivity[0])
41
+ end_index_pseudo = jnp.zeros_like(apec_file.pseudo_end_index[0])
42
+
43
+ else:
44
+ pseudo_energy_array = jnp.asarray(apec_file.pseudo_energy[idx])
45
+ pseudo_emissivity_array = jnp.asarray(apec_file.pseudo_emissivity[idx])
46
+ end_index_pseudo = jnp.asarray(apec_file.pseudo_end_index[idx])
47
+
48
+ return pseudo_energy_array, pseudo_emissivity_array, end_index_pseudo
49
+
50
+
51
+ def lines_table_getter(idx):
52
+ if idx > len(apec_file.temperature) - 2:
53
+ line_energy_array = jnp.zeros_like(apec_file.line_energy[0])
54
+ line_element_array = jnp.zeros_like(apec_file.line_element[0])
55
+ line_emissivity_array = jnp.zeros_like(apec_file.line_emissivity[0])
56
+ end_index_lines = jnp.zeros_like(apec_file.line_end_index[0])
57
+
58
+ else:
59
+ line_energy_array = jnp.asarray(apec_file.line_energy[idx])
60
+ line_element_array = jnp.asarray(apec_file.line_element[idx])
61
+ line_emissivity_array = jnp.asarray(apec_file.line_emissivity[idx])
62
+ end_index_lines = jnp.asarray(apec_file.line_end_index[idx])
63
+
64
+ return line_energy_array, line_element_array, line_emissivity_array, end_index_lines
65
+
66
+
67
+ pure_callback_temperature_shape = jax.eval_shape(lambda: jax.tree.map(jnp.asarray, temperature_table_getter(10.0)))
68
+ pure_callback_continuum_shape = jax.eval_shape(lambda: jax.tree.map(jnp.asarray, continuum_table_getter(0)))
69
+ pure_callback_pseudo_shape = jax.eval_shape(lambda: jax.tree.map(jnp.asarray, pseudo_table_getter(0)))
70
+ pure_callback_line_shape = jax.eval_shape(lambda: jax.tree.map(jnp.asarray, lines_table_getter(0)))
71
+
72
+
73
+ @jax.jit
74
+ def get_temperature(kT):
75
+ return jax.pure_callback(temperature_table_getter, pure_callback_temperature_shape, kT)
76
+
77
+
78
+ @jax.jit
79
+ def get_continuum(idx):
80
+ return jax.pure_callback(continuum_table_getter, pure_callback_continuum_shape, idx)
81
+
82
+
83
+ @jax.jit
84
+ def get_pseudo(idx):
85
+ return jax.pure_callback(pseudo_table_getter, pure_callback_pseudo_shape, idx)
86
+
87
+
88
+ @jax.jit
89
+ def get_lines(idx):
90
+ return jax.pure_callback(lines_table_getter, pure_callback_line_shape, idx)
jaxspec/model/abc.py CHANGED
@@ -123,9 +123,9 @@ class SpectralModel:
123
123
 
124
124
  !!! info
125
125
  This method is internally used in the inference process and should not be used directly. See
126
- [`photon_flux`][jaxspec.analysis.results.ChainResult.photon_flux] to compute
126
+ [`photon_flux`][jaxspec.analysis.results.FitResult.photon_flux] to compute
127
127
  the photon flux associated with a set of fitted parameters in a
128
- [`ChainResult`][jaxspec.analysis.results.ChainResult]
128
+ [`FitResult`][jaxspec.analysis.results.FitResult]
129
129
  instead.
130
130
  """
131
131
 
@@ -151,9 +151,9 @@ class SpectralModel:
151
151
 
152
152
  !!! info
153
153
  This method is internally used in the inference process and should not be used directly. See
154
- [`energy_flux`](/references/results/#jaxspec.analysis.results.ChainResult.energy_flux) to compute
154
+ [`energy_flux`](/references/results/#jaxspec.analysis.results.FitResult.energy_flux) to compute
155
155
  the energy flux associated with a set of fitted parameters in a
156
- [`ChainResult`](/references/results/#jaxspec.analysis.results.ChainResult)
156
+ [`FitResult`](/references/results/#jaxspec.analysis.results.FitResult)
157
157
  instead.
158
158
  """
159
159
 
@@ -309,7 +309,7 @@ class SpectralModel:
309
309
  if component.type == "additive":
310
310
 
311
311
  def lam_func(e):
312
- return component().continuum(e) + component().emission_lines(e, e + 1)[0]
312
+ return component(**kwargs).continuum(e) + component(**kwargs).emission_lines(e, e + 1)[0]
313
313
 
314
314
  elif component.type == "multiplicative":
315
315
 
@@ -326,7 +326,7 @@ class SpectralModel:
326
326
  "component_type": component.type,
327
327
  "name": component.__name__.lower(),
328
328
  "component": component,
329
- "params": hk.transform(lam_func).init(None, jnp.ones(1)),
329
+ # "params": hk.transform(lam_func).init(None, jnp.ones(1)),
330
330
  "fine_structure": False,
331
331
  "kwargs": kwargs,
332
332
  "depth": 0,
@@ -454,7 +454,7 @@ class ComponentMetaClass(type(hk.Module)):
454
454
  syntax while still enabling the components to be used as haiku modules.
455
455
  """
456
456
 
457
- def __call__(self, **kwargs):
457
+ def __call__(self, **kwargs) -> SpectralModel:
458
458
  """
459
459
  This method enable to use model components as haiku modules when folded in a haiku transform
460
460
  function and also to instantiate them as SpectralModel when out of a haiku transform
@@ -476,3 +476,51 @@ class ModelComponent(hk.Module, ABC, metaclass=ComponentMetaClass):
476
476
 
477
477
  def __init__(self, *args, **kwargs):
478
478
  super().__init__(*args, **kwargs)
479
+
480
+
481
+ class AdditiveComponent(ModelComponent, ABC):
482
+ type = "additive"
483
+
484
+ def continuum(self, energy):
485
+ """
486
+ Method for computing the continuum associated to the model.
487
+ By default, this is set to 0, which means that the model has no continuum.
488
+ This should be overloaded by the user if the model has a continuum.
489
+ """
490
+
491
+ return jnp.zeros_like(energy)
492
+
493
+ def emission_lines(self, e_min, e_max) -> (jax.Array, jax.Array):
494
+ """
495
+ Method for computing the fine structure of an additive model between two energies.
496
+ By default, this is set to 0, which means that the model has no emission lines.
497
+ This should be overloaded by the user if the model has a fine structure.
498
+ """
499
+
500
+ return jnp.zeros_like(e_min), (e_min + e_max) / 2
501
+
502
+ '''
503
+ def integral(self, e_min, e_max):
504
+ r"""
505
+ Method for integrating an additive model between two energies. It relies on
506
+ double exponential quadrature for finite intervals to compute an approximation
507
+ of the integral of a model.
508
+
509
+ references
510
+ ----------
511
+ * $Takahasi and Mori (1974) <https://ems.press/journals/prims/articles/2686>$_
512
+ * $Mori and Sugihara (2001) <https://doi.org/10.1016/S0377-0427(00)00501-X>$_
513
+ * $Tanh-sinh quadrature <https://en.wikipedia.org/wiki/Tanh-sinh_quadrature>$_ from Wikipedia
514
+
515
+ """
516
+
517
+ t = jnp.linspace(-4, 4, 71) # The number of points used is hardcoded and this is not ideal
518
+ # Quadrature nodes as defined in reference
519
+ phi = jnp.tanh(jnp.pi / 2 * jnp.sinh(t))
520
+ dphi = jnp.pi / 2 * jnp.cosh(t) * (1 / jnp.cosh(jnp.pi / 2 * jnp.sinh(t)) ** 2)
521
+ # Change of variable to turn the integral from E_min to E_max into an integral from -1 to 1
522
+ x = (e_max - e_min) / 2 * phi + (e_max + e_min) / 2
523
+ dx = (e_max - e_min) / 2 * dphi
524
+
525
+ return jnp.trapz(self(x) * dx, x=t)
526
+ '''
jaxspec/model/additive.py CHANGED
@@ -1,6 +1,5 @@
1
1
  from __future__ import annotations
2
2
 
3
- from abc import ABC
4
3
 
5
4
  import haiku as hk
6
5
  import jax
@@ -8,58 +7,10 @@ import jax.numpy as jnp
8
7
  import jax.scipy as jsp
9
8
  import astropy.units as u
10
9
  import astropy.constants
11
-
12
- from .abc import ModelComponent
13
10
  from haiku.initializers import Constant as HaikuConstant
14
11
  from ..util.integrate import integrate_interval
15
-
16
-
17
- class AdditiveComponent(ModelComponent, ABC):
18
- type = "additive"
19
-
20
- def continuum(self, energy):
21
- """
22
- Method for computing the continuum associated to the model.
23
- By default, this is set to 0, which means that the model has no continuum.
24
- This should be overloaded by the user if the model has a continuum.
25
- """
26
-
27
- return jnp.zeros_like(energy)
28
-
29
- def emission_lines(self, e_min, e_max) -> (jax.Array, jax.Array):
30
- """
31
- Method for computing the fine structure of an additive model between two energies.
32
- By default, this is set to 0, which means that the model has no emission lines.
33
- This should be overloaded by the user if the model has a fine structure.
34
- """
35
-
36
- return jnp.zeros_like(e_min), (e_min + e_max) / 2
37
-
38
- '''
39
- def integral(self, e_min, e_max):
40
- r"""
41
- Method for integrating an additive model between two energies. It relies on
42
- double exponential quadrature for finite intervals to compute an approximation
43
- of the integral of a model.
44
-
45
- references
46
- ----------
47
- * $Takahasi and Mori (1974) <https://ems.press/journals/prims/articles/2686>$_
48
- * $Mori and Sugihara (2001) <https://doi.org/10.1016/S0377-0427(00)00501-X>$_
49
- * $Tanh-sinh quadrature <https://en.wikipedia.org/wiki/Tanh-sinh_quadrature>$_ from Wikipedia
50
-
51
- """
52
-
53
- t = jnp.linspace(-4, 4, 71) # The number of points used is hardcoded and this is not ideal
54
- # Quadrature nodes as defined in reference
55
- phi = jnp.tanh(jnp.pi / 2 * jnp.sinh(t))
56
- dphi = jnp.pi / 2 * jnp.cosh(t) * (1 / jnp.cosh(jnp.pi / 2 * jnp.sinh(t)) ** 2)
57
- # Change of variable to turn the integral from E_min to E_max into an integral from -1 to 1
58
- x = (e_max - e_min) / 2 * phi + (e_max + e_min) / 2
59
- dx = (e_max - e_min) / 2 * dphi
60
-
61
- return jnp.trapz(self(x) * dx, x=t)
62
- '''
12
+ from .abc import AdditiveComponent
13
+ # from ._additive.apec import APEC # noqa: F401
63
14
 
64
15
 
65
16
  class Powerlaw(AdditiveComponent):
@@ -0,0 +1,31 @@
1
+ Element angr aspl feld aneb grsa wilm lodd lgpp lgps
2
+ H 1.00e+00 1.00e+00 1.00e+00 1.00e+00 1.00e+00 1.00e+00 1.00e+00 1.00E+00 1.00E+00
3
+ He 9.77e-02 8.51e-02 9.77e-02 8.01e-02 8.51e-02 9.77e-02 7.92e-02 8.41E-02 9.69E-02
4
+ Li 1.45e-11 1.12e-11 1.26e-11 2.19e-09 1.26e-11 0.00 1.90e-09 1.26E-11 2.15E-09
5
+ Be 1.41e-11 2.40e-11 2.51e-11 2.87e-11 2.51e-11 0.00 2.57e-11 2.40E-11 2.36E-11
6
+ B 3.98e-10 5.01e-10 3.55e-10 8.82e-10 3.55e-10 0.00 6.03e-10 5.01E-10 7.26E-10
7
+ C 3.63e-04 2.69e-04 3.98e-04 4.45e-04 3.31e-04 2.40e-04 2.45e-04 2.45E-04 2.78E-04
8
+ N 1.12e-04 6.76e-05 1.00e-04 9.12e-05 8.32e-05 7.59e-05 6.76e-05 7.24E-05 8.19E-05
9
+ O 8.51e-04 4.90e-04 8.51e-04 7.39e-04 6.76e-04 4.90e-04 4.90e-04 5.37E-04 6.06E-04
10
+ F 3.63e-08 3.63e-08 3.63e-08 3.10e-08 3.63e-08 0.00 2.88e-08 3.63E-08 3.10E-08
11
+ Ne 1.23e-04 8.51e-05 1.29e-04 1.38e-04 1.20e-04 8.71e-05 7.41e-05 1.12E-04 1.27E-04
12
+ Na 2.14e-06 1.74e-06 2.14e-06 2.10e-06 2.14e-06 1.45e-06 1.99e-06 2.00E-06 2.23E-06
13
+ Mg 3.80e-05 3.98e-05 3.80e-05 3.95e-05 3.80e-05 2.51e-05 3.55e-05 3.47E-05 3.98E-05
14
+ Al 2.95e-06 2.82e-06 2.95e-06 3.12e-06 2.95e-06 2.14e-06 2.88e-06 2.95E-06 3.27E-06
15
+ Si 3.55e-05 3.24e-05 3.55e-05 3.68e-05 3.55e-05 1.86e-05 3.47e-05 3.31E-05 3.86E-05
16
+ P 2.82e-07 2.57e-07 2.82e-07 3.82e-07 2.82e-07 2.63e-07 2.88e-07 2.88E-07 3.20E-07
17
+ S 1.62e-05 1.32e-05 1.62e-05 1.89e-05 2.14e-05 1.23e-05 1.55e-05 1.38E-05 1.63E-05
18
+ Cl 3.16e-07 3.16e-07 3.16e-07 1.93e-07 3.16e-07 1.32e-07 1.82e-07 3.16E-07 2.00E-07
19
+ Ar 3.63e-06 2.51e-06 4.47e-06 3.82e-06 2.51e-06 2.57e-06 3.55e-06 3.16E-06 3.58E-06
20
+ K 1.32e-07 1.07e-07 1.32e-07 1.39e-07 1.32e-07 0.00 1.29e-07 1.32E-07 1.45E-07
21
+ Ca 2.29e-06 2.19e-06 2.29e-06 2.25e-06 2.29e-06 1.58e-06 2.19e-06 2.14E-06 2.33E-06
22
+ Sc 1.26e-09 1.41e-09 1.48e-09 1.24e-09 1.48e-09 0.00 1.17e-09 1.26E-09 1.33E-09
23
+ Ti 9.77e-08 8.91e-08 1.05e-07 8.82e-08 1.05e-07 6.46e-08 8.32e-08 7.94E-08 9.54E-08
24
+ V 1.00e-08 8.51e-09 1.00e-08 1.08e-08 1.00e-08 0.00 1.00e-08 1.00E-08 1.11E-08
25
+ Cr 4.68e-07 4.37e-07 4.68e-07 4.93e-07 4.68e-07 3.24e-07 4.47e-07 4.37E-07 5.06E-07
26
+ Mn 2.45e-07 2.69e-07 2.45e-07 3.50e-07 2.45e-07 2.19e-07 3.16e-07 2.34E-07 3.56E-07
27
+ Fe 4.68e-05 3.16e-05 3.24e-05 3.31e-05 3.16e-05 2.69e-05 2.95e-05 2.82E-05 3.27E-05
28
+ Co 8.32e-08 9.77e-08 8.32e-08 8.27e-08 8.32e-08 8.32e-08 8.13e-08 8.32E-08 9.07E-08
29
+ Ni 1.78e-06 1.66e-06 1.78e-06 1.81e-06 1.78e-06 1.12e-06 1.66e-06 1.70E-06 1.89E-06
30
+ Cu 1.62e-08 1.55e-08 1.62e-08 1.89e-08 1.62e-08 0.00 1.82e-08 1.62E-08 2.09E-08
31
+ Zn 3.98e-08 3.63e-08 3.98e-08 4.63e-08 3.98e-08 0.00 4.27e-08 4.17E-08 5.02E-08