jaxspec 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxspec/analysis/results.py +297 -121
- jaxspec/data/__init__.py +4 -4
- jaxspec/data/obsconf.py +53 -8
- jaxspec/data/util.py +114 -84
- jaxspec/fit.py +335 -96
- jaxspec/model/__init__.py +0 -1
- jaxspec/model/_additive/apec.py +56 -117
- jaxspec/model/_additive/apec_loaders.py +42 -59
- jaxspec/model/additive.py +194 -55
- jaxspec/model/background.py +50 -16
- jaxspec/model/multiplicative.py +63 -41
- jaxspec/util/__init__.py +45 -0
- jaxspec/util/abundance.py +5 -3
- jaxspec/util/online_storage.py +28 -0
- jaxspec/util/typing.py +43 -0
- {jaxspec-0.0.6.dist-info → jaxspec-0.0.8.dist-info}/METADATA +14 -10
- {jaxspec-0.0.6.dist-info → jaxspec-0.0.8.dist-info}/RECORD +19 -25
- {jaxspec-0.0.6.dist-info → jaxspec-0.0.8.dist-info}/WHEEL +1 -1
- jaxspec/data/example_data/MOS1.pha +0 -46
- jaxspec/data/example_data/MOS2.pha +0 -42
- jaxspec/data/example_data/PN.pha +1 -293
- jaxspec/data/example_data/fakeit.pha +1 -335
- jaxspec/tables/abundances.dat +0 -31
- jaxspec/tables/xsect_phabs_aspl.fits +0 -0
- jaxspec/tables/xsect_tbabs_wilm.fits +0 -0
- jaxspec/tables/xsect_wabs_angr.fits +0 -0
- {jaxspec-0.0.6.dist-info → jaxspec-0.0.8.dist-info}/LICENSE.md +0 -0
jaxspec/model/_additive/apec.py
CHANGED
|
@@ -1,16 +1,21 @@
|
|
|
1
|
-
import
|
|
2
|
-
|
|
3
|
-
|
|
1
|
+
import warnings
|
|
2
|
+
|
|
3
|
+
from typing import Literal
|
|
4
|
+
|
|
4
5
|
import astropy.units as u
|
|
6
|
+
import haiku as hk
|
|
7
|
+
import jax
|
|
8
|
+
import jax.numpy as jnp
|
|
9
|
+
|
|
10
|
+
from astropy.constants import c, m_p
|
|
11
|
+
from haiku.initializers import Constant as HaikuConstant
|
|
5
12
|
from jax import lax
|
|
6
|
-
from jax.lax import
|
|
13
|
+
from jax.lax import fori_loop, scan
|
|
7
14
|
from jax.scipy.stats import norm as gaussian
|
|
8
|
-
|
|
15
|
+
|
|
9
16
|
from ...util.abundance import abundance_table, element_data
|
|
10
|
-
from haiku.initializers import Constant as HaikuConstant
|
|
11
|
-
from astropy.constants import c, m_p
|
|
12
17
|
from ..abc import AdditiveComponent
|
|
13
|
-
from .apec_loaders import
|
|
18
|
+
from .apec_loaders import get_continuum, get_lines, get_pseudo, get_temperature
|
|
14
19
|
|
|
15
20
|
|
|
16
21
|
@jax.jit
|
|
@@ -51,7 +56,9 @@ def interp_and_integrate(energy_low, energy_high, energy_ref, continuum_ref, end
|
|
|
51
56
|
# Within
|
|
52
57
|
|
|
53
58
|
current_energy_is_between = (energy_low <= current_energy) * (current_energy < energy_high)
|
|
54
|
-
previous_energy_is_between = (energy_low <= previous_energy) * (
|
|
59
|
+
previous_energy_is_between = (energy_low <= previous_energy) * (
|
|
60
|
+
previous_energy < energy_high
|
|
61
|
+
)
|
|
55
62
|
energies_within_bins = (previous_energy <= energy_low) * (energy_high < current_energy)
|
|
56
63
|
|
|
57
64
|
case = (
|
|
@@ -68,7 +75,9 @@ def interp_and_integrate(energy_low, energy_high, energy_ref, continuum_ref, end
|
|
|
68
75
|
lambda pe, pc, ce, cc, el, er: (cc + lerp(el, pe, ce, pc, cc)) * (ce - el) / 2, # 2
|
|
69
76
|
lambda pe, pc, ce, cc, el, er: (pc + lerp(er, pe, ce, pc, cc)) * (er - pe) / 2, # 3
|
|
70
77
|
lambda pe, pc, ce, cc, el, er: (pc + cc) * (ce - pe) / 2, # 4
|
|
71
|
-
lambda pe, pc, ce, cc, el, er: (lerp(el, pe, ce, pc, cc) + lerp(er, pe, ce, pc, cc))
|
|
78
|
+
lambda pe, pc, ce, cc, el, er: (lerp(el, pe, ce, pc, cc) + lerp(er, pe, ce, pc, cc))
|
|
79
|
+
* (er - el)
|
|
80
|
+
/ 2,
|
|
72
81
|
# 5
|
|
73
82
|
],
|
|
74
83
|
previous_energy,
|
|
@@ -86,23 +95,24 @@ def interp_and_integrate(energy_low, energy_high, energy_ref, continuum_ref, end
|
|
|
86
95
|
return integrated_flux
|
|
87
96
|
|
|
88
97
|
|
|
98
|
+
@jax.jit
|
|
89
99
|
def interp(e_low, e_high, energy_ref, continuum_ref, end_index):
|
|
90
100
|
energy_ref = jnp.where(jnp.arange(energy_ref.shape[0]) < end_index, energy_ref, jnp.nan)
|
|
91
101
|
|
|
92
|
-
return (
|
|
102
|
+
return (
|
|
103
|
+
jnp.interp(e_high, energy_ref, continuum_ref) - jnp.interp(e_low, energy_ref, continuum_ref)
|
|
104
|
+
) / (e_high - e_low)
|
|
93
105
|
|
|
94
106
|
|
|
95
|
-
|
|
107
|
+
@jax.jit
|
|
108
|
+
def interp_flux(energy, energy_ref, continuum_ref, end_index):
|
|
96
109
|
"""
|
|
97
110
|
Iterate through an array of shape (energy_ref,) and compute the flux between the bins defined by energy
|
|
98
111
|
"""
|
|
99
112
|
|
|
100
113
|
def scanned_func(carry, unpack):
|
|
101
114
|
e_low, e_high = unpack
|
|
102
|
-
|
|
103
|
-
continuum = interp_and_integrate(e_low, e_high, energy_ref, continuum_ref, end_index)
|
|
104
|
-
else:
|
|
105
|
-
continuum = interp(e_low, e_high, energy_ref, continuum_ref, end_index)
|
|
115
|
+
continuum = interp_and_integrate(e_low, e_high, energy_ref, continuum_ref, end_index)
|
|
106
116
|
|
|
107
117
|
return carry, continuum
|
|
108
118
|
|
|
@@ -111,7 +121,8 @@ def interp_flux(energy, energy_ref, continuum_ref, end_index, integrate=True):
|
|
|
111
121
|
return continuum
|
|
112
122
|
|
|
113
123
|
|
|
114
|
-
|
|
124
|
+
@jax.jit
|
|
125
|
+
def interp_flux_elements(energy_ref, continuum_ref, end_index, energy, abundances):
|
|
115
126
|
"""
|
|
116
127
|
Iterate through an array of shape (abundance, energy_ref) and compute the flux between the bins defined by energy
|
|
117
128
|
and weight the flux depending on the abundance of each element
|
|
@@ -119,7 +130,7 @@ def interp_flux_elements(energy_ref, continuum_ref, end_index, energy, abundance
|
|
|
119
130
|
|
|
120
131
|
def scanned_func(_, unpack):
|
|
121
132
|
energy_ref, continuum_ref, end_idx = unpack
|
|
122
|
-
element_flux = interp_flux(energy, energy_ref, continuum_ref, end_idx
|
|
133
|
+
element_flux = interp_flux(energy, energy_ref, continuum_ref, end_idx)
|
|
123
134
|
|
|
124
135
|
return _, element_flux
|
|
125
136
|
|
|
@@ -136,23 +147,9 @@ def get_lines_contribution_broadening(
|
|
|
136
147
|
# Notice the -1 in line element to match the 0-based indexing
|
|
137
148
|
l_energy, l_emissivity, l_element = line_energy[i], line_emissivity[i], line_element[i] - 1
|
|
138
149
|
broadening = l_energy * total_broadening[l_element]
|
|
139
|
-
l_flux = gaussian.cdf(energy[1:], l_energy, broadening) - gaussian.cdf(
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
return flux + l_flux
|
|
143
|
-
|
|
144
|
-
return fori_loop(0, end_index, body_func, jnp.zeros_like(energy[:-1]))
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
@jax.jit
|
|
148
|
-
def get_lines_contribution_broadening_derivative(
|
|
149
|
-
line_energy, line_element, line_emissivity, end_index, energy, abundances, total_broadening
|
|
150
|
-
):
|
|
151
|
-
def body_func(i, flux):
|
|
152
|
-
# Notice the -1 in line element to match the 0-based indexing
|
|
153
|
-
l_energy, l_emissivity, l_element = line_energy[i], line_emissivity[i], line_element[i] - 1
|
|
154
|
-
broadening = l_energy * total_broadening[l_element]
|
|
155
|
-
l_flux = gaussian.cdf(energy[1:], l_energy, broadening) - gaussian.cdf(energy[:-1], l_energy, broadening)
|
|
150
|
+
l_flux = gaussian.cdf(energy[1:], l_energy, broadening) - gaussian.cdf(
|
|
151
|
+
energy[:-1], l_energy, broadening
|
|
152
|
+
)
|
|
156
153
|
l_flux = l_flux * l_emissivity * abundances[l_element]
|
|
157
154
|
|
|
158
155
|
return flux + l_flux
|
|
@@ -160,7 +157,6 @@ def get_lines_contribution_broadening_derivative(
|
|
|
160
157
|
return fori_loop(0, end_index, body_func, jnp.zeros_like(energy[:-1]))
|
|
161
158
|
|
|
162
159
|
|
|
163
|
-
@jax.custom_jvp
|
|
164
160
|
@jax.jit
|
|
165
161
|
def continuum_func(energy, kT, abundances):
|
|
166
162
|
idx, kT_low, kT_high = get_temperature(kT)
|
|
@@ -170,96 +166,27 @@ def continuum_func(energy, kT, abundances):
|
|
|
170
166
|
return lerp(kT, kT_low, kT_high, continuum_low, continuum_high)
|
|
171
167
|
|
|
172
168
|
|
|
173
|
-
@jax.jit
|
|
174
|
-
@continuum_func.defjvp
|
|
175
|
-
def continuum_jvp(primals, tangents):
|
|
176
|
-
energy, kT, abundances = primals
|
|
177
|
-
energy_dot, kT_dot, abundances_dot = tangents
|
|
178
|
-
|
|
179
|
-
idx, kT_low, kT_high = get_temperature(kT)
|
|
180
|
-
continuum_low_pars = get_continuum(idx)
|
|
181
|
-
continuum_high_pars = get_continuum(idx + 1)
|
|
182
|
-
|
|
183
|
-
# Energy derivative
|
|
184
|
-
dcontinuum_low_denerg = interp_flux_elements(*continuum_low_pars, energy, abundances, integrate=False)
|
|
185
|
-
dcontinuum_high_denerg = interp_flux_elements(*continuum_high_pars, energy, abundances, integrate=False)
|
|
186
|
-
energy_derivative = lerp(kT, kT_low, kT_high, dcontinuum_low_denerg, dcontinuum_high_denerg) * jnp.diff(energy_dot)
|
|
187
|
-
|
|
188
|
-
# Temperature derivative
|
|
189
|
-
continuum_low = interp_flux_elements(*continuum_low_pars, energy, abundances)
|
|
190
|
-
continuum_high = interp_flux_elements(*continuum_high_pars, energy, abundances)
|
|
191
|
-
kT_derivative = (continuum_high - continuum_low) / (kT_high - kT_low) * kT_dot
|
|
192
|
-
|
|
193
|
-
# Abundances derivative
|
|
194
|
-
dcontinuum_low_dabund = interp_flux_elements(*continuum_low_pars, energy, abundances_dot)
|
|
195
|
-
dcontinuum_high_dabund = interp_flux_elements(*continuum_high_pars, energy, abundances_dot)
|
|
196
|
-
abundances_derivative = lerp(kT, kT_low, kT_high, dcontinuum_low_dabund, dcontinuum_high_dabund)
|
|
197
|
-
|
|
198
|
-
primals_out = continuum_func(*primals)
|
|
199
|
-
|
|
200
|
-
return primals_out, energy_derivative + kT_derivative + abundances_derivative
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
@jax.custom_jvp
|
|
204
169
|
@jax.jit
|
|
205
170
|
def pseudo_func(energy, kT, abundances):
|
|
206
171
|
idx, kT_low, kT_high = get_temperature(kT)
|
|
207
|
-
continuum_low = interp_flux_elements(*
|
|
208
|
-
continuum_high = interp_flux_elements(*
|
|
172
|
+
continuum_low = interp_flux_elements(*get_pseudo(idx), energy, abundances)
|
|
173
|
+
continuum_high = interp_flux_elements(*get_pseudo(idx + 1), energy, abundances)
|
|
209
174
|
|
|
210
175
|
return lerp(kT, kT_low, kT_high, continuum_low, continuum_high)
|
|
211
176
|
|
|
212
177
|
|
|
213
|
-
@jax.
|
|
214
|
-
@pseudo_func.defjvp
|
|
215
|
-
def pseudo_jvp(primals, tangents):
|
|
216
|
-
energy, kT, abundances = primals
|
|
217
|
-
energy_dot, kT_dot, abundances_dot = tangents
|
|
218
|
-
|
|
219
|
-
idx, kT_low, kT_high = get_temperature(kT)
|
|
220
|
-
continuum_low_pars = get_pseudo(idx)
|
|
221
|
-
continuum_high_pars = get_pseudo(idx + 1)
|
|
222
|
-
|
|
223
|
-
# Energy derivative
|
|
224
|
-
dcontinuum_low_denerg = interp_flux_elements(*continuum_low_pars, energy, abundances, integrate=False)
|
|
225
|
-
dcontinuum_high_denerg = interp_flux_elements(*continuum_high_pars, energy, abundances, integrate=False)
|
|
226
|
-
energy_derivative = lerp(kT, kT_low, kT_high, dcontinuum_low_denerg, dcontinuum_high_denerg) * jnp.diff(energy_dot)
|
|
227
|
-
|
|
228
|
-
# Temperature derivative
|
|
229
|
-
continuum_low = interp_flux_elements(*continuum_low_pars, energy, abundances)
|
|
230
|
-
continuum_high = interp_flux_elements(*continuum_high_pars, energy, abundances)
|
|
231
|
-
kT_derivative = (continuum_high - continuum_low) / (kT_high - kT_low) * kT_dot
|
|
232
|
-
|
|
233
|
-
# Abundances derivative
|
|
234
|
-
dcontinuum_low_dabund = interp_flux_elements(*continuum_low_pars, energy, abundances_dot)
|
|
235
|
-
dcontinuum_high_dabund = interp_flux_elements(*continuum_high_pars, energy, abundances_dot)
|
|
236
|
-
abundances_derivative = lerp(kT, kT_low, kT_high, dcontinuum_low_dabund, dcontinuum_high_dabund)
|
|
237
|
-
|
|
238
|
-
primals_out = pseudo_func(*primals)
|
|
239
|
-
|
|
240
|
-
return primals_out, energy_derivative + kT_derivative + abundances_derivative
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
@jax.custom_jvp
|
|
178
|
+
# @jax.custom_jvp
|
|
244
179
|
@jax.jit
|
|
245
180
|
def lines_func(energy, kT, abundances, broadening):
|
|
246
181
|
idx, kT_low, kT_high = get_temperature(kT)
|
|
247
182
|
line_low = get_lines_contribution_broadening(*get_lines(idx), energy, abundances, broadening)
|
|
248
|
-
line_high = get_lines_contribution_broadening(
|
|
183
|
+
line_high = get_lines_contribution_broadening(
|
|
184
|
+
*get_lines(idx + 1), energy, abundances, broadening
|
|
185
|
+
)
|
|
249
186
|
|
|
250
187
|
return lerp(kT, kT_low, kT_high, line_low, line_high)
|
|
251
188
|
|
|
252
189
|
|
|
253
|
-
@jax.jit
|
|
254
|
-
@lines_func.defjvp
|
|
255
|
-
def lines_jvp(primals, tangents):
|
|
256
|
-
energy, kT, abundances, broadening = primals
|
|
257
|
-
energy_dot, kT_dot, abundances_dot, broadening_dot = tangents
|
|
258
|
-
|
|
259
|
-
primals_out = lines_func(*primals)
|
|
260
|
-
return primals_out, jnp.zeros_like(primals_out)
|
|
261
|
-
|
|
262
|
-
|
|
263
190
|
class APEC(AdditiveComponent):
|
|
264
191
|
"""
|
|
265
192
|
APEC model implementation in pure JAX for X-ray spectral fitting.
|
|
@@ -276,11 +203,15 @@ class APEC(AdditiveComponent):
|
|
|
276
203
|
thermal_broadening: bool = True,
|
|
277
204
|
turbulent_broadening: bool = True,
|
|
278
205
|
variant: Literal["none", "v", "vv"] = "none",
|
|
279
|
-
abundance_table: Literal[
|
|
206
|
+
abundance_table: Literal[
|
|
207
|
+
"angr", "aspl", "feld", "aneb", "grsa", "wilm", "lodd", "lgpp", "lgps"
|
|
208
|
+
] = "angr",
|
|
280
209
|
trace_abundance: float = 1.0,
|
|
281
210
|
**kwargs,
|
|
282
211
|
):
|
|
283
|
-
super(
|
|
212
|
+
super().__init__(**kwargs)
|
|
213
|
+
|
|
214
|
+
warnings.warn("Be aware that this APEC implementation is not meant to be used yet")
|
|
284
215
|
|
|
285
216
|
self.atomic_weights = jnp.asarray(element_data["atomic_weight"].to_numpy())
|
|
286
217
|
|
|
@@ -325,14 +256,18 @@ class APEC(AdditiveComponent):
|
|
|
325
256
|
"""
|
|
326
257
|
if self.turbulent_broadening:
|
|
327
258
|
# This return value must be multiplied by the energy of the line to get actual broadening
|
|
328
|
-
return
|
|
259
|
+
return (
|
|
260
|
+
hk.get_parameter("Velocity", [], init=HaikuConstant(100.0)) / c.to(u.km / u.s).value
|
|
261
|
+
)
|
|
329
262
|
else:
|
|
330
263
|
return 0.0
|
|
331
264
|
|
|
332
265
|
def get_parameters(self):
|
|
333
266
|
none_elements = ["C", "N", "O", "Ne", "Mg", "Al", "Si", "S", "Ar", "Ca", "Fe", "Ni"]
|
|
334
267
|
v_elements = ["He", "C", "N", "O", "Ne", "Mg", "Al", "Si", "S", "Ar", "Ca", "Fe", "Ni"]
|
|
335
|
-
trace_elements =
|
|
268
|
+
trace_elements = (
|
|
269
|
+
jnp.asarray([3, 4, 5, 9, 11, 15, 17, 19, 21, 22, 23, 24, 25, 27, 29, 30], dtype=int) - 1
|
|
270
|
+
)
|
|
336
271
|
|
|
337
272
|
# Set abundances of trace element (will be overwritten in the vv case)
|
|
338
273
|
abund = jnp.ones((30,)).at[trace_elements].multiply(self.trace_abundance)
|
|
@@ -354,7 +289,9 @@ class APEC(AdditiveComponent):
|
|
|
354
289
|
abund = abund.at[i].set(Z)
|
|
355
290
|
|
|
356
291
|
if abund != "angr":
|
|
357
|
-
abund = abund * jnp.asarray(
|
|
292
|
+
abund = abund * jnp.asarray(
|
|
293
|
+
abundance_table[self.abundance_table] / abundance_table["angr"]
|
|
294
|
+
)
|
|
358
295
|
|
|
359
296
|
# Set the temperature, redshift, normalisation
|
|
360
297
|
kT = hk.get_parameter("kT", [], init=HaikuConstant(6.5))
|
|
@@ -372,6 +309,8 @@ class APEC(AdditiveComponent):
|
|
|
372
309
|
|
|
373
310
|
continuum = continuum_func(energy, kT, abundances) if self.continuum_to_compute else 0.0
|
|
374
311
|
pseudo_continuum = pseudo_func(energy, kT, abundances) if self.pseudo_to_compute else 0.0
|
|
375
|
-
lines =
|
|
312
|
+
lines = (
|
|
313
|
+
lines_func(energy, kT, abundances, total_broadening) if self.lines_to_compute else 0.0
|
|
314
|
+
)
|
|
376
315
|
|
|
377
316
|
return (continuum + pseudo_continuum + lines) * norm * 1e14 / (1 + z), (e_low + e_high) / 2
|
|
@@ -1,90 +1,73 @@
|
|
|
1
|
-
"""
|
|
2
|
-
pure callback to enable reading data from the files without saturating the memory.
|
|
1
|
+
"""This module contains the functions that load the APEC tables from the HDF5 file. They are implemented as JAX
|
|
2
|
+
pure callback to enable reading data from the files without saturating the memory."""
|
|
3
3
|
|
|
4
|
+
import h5netcdf
|
|
4
5
|
import jax
|
|
5
6
|
import jax.numpy as jnp
|
|
6
|
-
import numpy as np
|
|
7
|
-
import importlib.resources
|
|
8
|
-
import xarray as xr
|
|
9
7
|
|
|
8
|
+
from ...util.online_storage import table_manager
|
|
10
9
|
|
|
11
|
-
apec_file = xr.open_dataset(importlib.resources.files("jaxspec") / "tables/apec.nc", engine="h5netcdf")
|
|
12
10
|
|
|
11
|
+
@jax.jit
|
|
12
|
+
def temperature_table_getter():
|
|
13
|
+
with h5netcdf.File(table_manager.fetch("apec.nc"), "r") as f:
|
|
14
|
+
temperature = jnp.asarray(f["/temperature"])
|
|
13
15
|
|
|
14
|
-
|
|
15
|
-
idx = np.searchsorted(apec_file.temperature, kT) - 1
|
|
16
|
-
|
|
17
|
-
if idx > len(apec_file.temperature) - 2:
|
|
18
|
-
return idx, 0.0, 0.0
|
|
19
|
-
else:
|
|
20
|
-
return idx, float(apec_file.temperature[idx]), float(apec_file.temperature[idx + 1])
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
def continuum_table_getter(idx):
|
|
24
|
-
if idx > len(apec_file.temperature) - 2:
|
|
25
|
-
continuum_energy_array = jnp.zeros_like(apec_file.continuum_energy[0])
|
|
26
|
-
continuum_emissivity_array = jnp.zeros_like(apec_file.continuum_emissivity[0])
|
|
27
|
-
end_index_continuum = jnp.zeros_like(apec_file.continuum_end_index[0])
|
|
28
|
-
|
|
29
|
-
else:
|
|
30
|
-
continuum_energy_array = jnp.asarray(apec_file.continuum_energy[idx])
|
|
31
|
-
continuum_emissivity_array = jnp.asarray(apec_file.continuum_emissivity[idx])
|
|
32
|
-
end_index_continuum = jnp.asarray(apec_file.continuum_end_index[idx])
|
|
33
|
-
|
|
34
|
-
return continuum_energy_array, continuum_emissivity_array, end_index_continuum
|
|
35
|
-
|
|
16
|
+
return temperature
|
|
36
17
|
|
|
37
|
-
def pseudo_table_getter(idx):
|
|
38
|
-
if idx > len(apec_file.temperature) - 2:
|
|
39
|
-
pseudo_energy_array = jnp.zeros_like(apec_file.pseudo_energy[0])
|
|
40
|
-
pseudo_emissivity_array = jnp.zeros_like(apec_file.pseudo_emissivity[0])
|
|
41
|
-
end_index_pseudo = jnp.zeros_like(apec_file.pseudo_end_index[0])
|
|
42
18
|
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
19
|
+
@jax.jit
|
|
20
|
+
def get_temperature(kT):
|
|
21
|
+
temperature = temperature_table_getter()
|
|
22
|
+
idx = jnp.searchsorted(temperature, kT) - 1
|
|
47
23
|
|
|
48
|
-
return
|
|
24
|
+
return idx, temperature[idx], temperature[idx + 1]
|
|
49
25
|
|
|
50
26
|
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
27
|
+
@jax.jit
|
|
28
|
+
def continuum_table_getter():
|
|
29
|
+
with h5netcdf.File(table_manager.fetch("apec.nc"), "r") as f:
|
|
30
|
+
continuum_energy = jnp.asarray(f["/continuum_energy"])
|
|
31
|
+
continuum_emissivity = jnp.asarray(f["/continuum_emissivity"])
|
|
32
|
+
continuum_end_index = jnp.asarray(f["/continuum_end_index"])
|
|
57
33
|
|
|
58
|
-
|
|
59
|
-
line_energy_array = jnp.asarray(apec_file.line_energy[idx])
|
|
60
|
-
line_element_array = jnp.asarray(apec_file.line_element[idx])
|
|
61
|
-
line_emissivity_array = jnp.asarray(apec_file.line_emissivity[idx])
|
|
62
|
-
end_index_lines = jnp.asarray(apec_file.line_end_index[idx])
|
|
34
|
+
return continuum_energy, continuum_emissivity, continuum_end_index
|
|
63
35
|
|
|
64
|
-
return line_energy_array, line_element_array, line_emissivity_array, end_index_lines
|
|
65
36
|
|
|
37
|
+
@jax.jit
|
|
38
|
+
def pseudo_table_getter():
|
|
39
|
+
with h5netcdf.File(table_manager.fetch("apec.nc"), "r") as f:
|
|
40
|
+
pseudo_energy = jnp.asarray(f["/pseudo_energy"])
|
|
41
|
+
pseudo_emissivity = jnp.asarray(f["/pseudo_emissivity"])
|
|
42
|
+
pseudo_end_index = jnp.asarray(f["/pseudo_end_index"])
|
|
66
43
|
|
|
67
|
-
|
|
68
|
-
pure_callback_continuum_shape = jax.eval_shape(lambda: jax.tree.map(jnp.asarray, continuum_table_getter(0)))
|
|
69
|
-
pure_callback_pseudo_shape = jax.eval_shape(lambda: jax.tree.map(jnp.asarray, pseudo_table_getter(0)))
|
|
70
|
-
pure_callback_line_shape = jax.eval_shape(lambda: jax.tree.map(jnp.asarray, lines_table_getter(0)))
|
|
44
|
+
return pseudo_energy, pseudo_emissivity, pseudo_end_index
|
|
71
45
|
|
|
72
46
|
|
|
73
47
|
@jax.jit
|
|
74
|
-
def
|
|
75
|
-
|
|
48
|
+
def line_table_getter():
|
|
49
|
+
with h5netcdf.File(table_manager.fetch("apec.nc"), "r") as f:
|
|
50
|
+
line_energy = jnp.asarray(f["/line_energy"])
|
|
51
|
+
line_element = jnp.asarray(f["/line_element"])
|
|
52
|
+
line_emissivity = jnp.asarray(f["/line_emissivity"])
|
|
53
|
+
line_end_index = jnp.asarray(f["/line_end_index"])
|
|
54
|
+
|
|
55
|
+
return line_energy, line_element, line_emissivity, line_end_index
|
|
76
56
|
|
|
77
57
|
|
|
78
58
|
@jax.jit
|
|
79
59
|
def get_continuum(idx):
|
|
80
|
-
|
|
60
|
+
continuum_energy, continuum_emissivity, continuum_end_index = continuum_table_getter()
|
|
61
|
+
return continuum_energy[idx], continuum_emissivity[idx], continuum_end_index[idx]
|
|
81
62
|
|
|
82
63
|
|
|
83
64
|
@jax.jit
|
|
84
65
|
def get_pseudo(idx):
|
|
85
|
-
|
|
66
|
+
pseudo_energy, pseudo_emissivity, pseudo_end_index = pseudo_table_getter()
|
|
67
|
+
return pseudo_energy[idx], pseudo_emissivity[idx], pseudo_end_index[idx]
|
|
86
68
|
|
|
87
69
|
|
|
88
70
|
@jax.jit
|
|
89
71
|
def get_lines(idx):
|
|
90
|
-
|
|
72
|
+
line_energy, line_element, line_emissivity, line_end_index = line_table_getter()
|
|
73
|
+
return line_energy[idx], line_element[idx], line_emissivity[idx], line_end_index[idx]
|