jaxspec 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,16 +1,21 @@
1
- import jax.numpy as jnp
2
- import jax
3
- import haiku as hk
1
+ import warnings
2
+
3
+ from typing import Literal
4
+
4
5
  import astropy.units as u
6
+ import haiku as hk
7
+ import jax
8
+ import jax.numpy as jnp
9
+
10
+ from astropy.constants import c, m_p
11
+ from haiku.initializers import Constant as HaikuConstant
5
12
  from jax import lax
6
- from jax.lax import scan, fori_loop
13
+ from jax.lax import fori_loop, scan
7
14
  from jax.scipy.stats import norm as gaussian
8
- from typing import Literal
15
+
9
16
  from ...util.abundance import abundance_table, element_data
10
- from haiku.initializers import Constant as HaikuConstant
11
- from astropy.constants import c, m_p
12
17
  from ..abc import AdditiveComponent
13
- from .apec_loaders import get_temperature, get_continuum, get_pseudo, get_lines
18
+ from .apec_loaders import get_continuum, get_lines, get_pseudo, get_temperature
14
19
 
15
20
 
16
21
  @jax.jit
@@ -51,7 +56,9 @@ def interp_and_integrate(energy_low, energy_high, energy_ref, continuum_ref, end
51
56
  # Within
52
57
 
53
58
  current_energy_is_between = (energy_low <= current_energy) * (current_energy < energy_high)
54
- previous_energy_is_between = (energy_low <= previous_energy) * (previous_energy < energy_high)
59
+ previous_energy_is_between = (energy_low <= previous_energy) * (
60
+ previous_energy < energy_high
61
+ )
55
62
  energies_within_bins = (previous_energy <= energy_low) * (energy_high < current_energy)
56
63
 
57
64
  case = (
@@ -68,7 +75,9 @@ def interp_and_integrate(energy_low, energy_high, energy_ref, continuum_ref, end
68
75
  lambda pe, pc, ce, cc, el, er: (cc + lerp(el, pe, ce, pc, cc)) * (ce - el) / 2, # 2
69
76
  lambda pe, pc, ce, cc, el, er: (pc + lerp(er, pe, ce, pc, cc)) * (er - pe) / 2, # 3
70
77
  lambda pe, pc, ce, cc, el, er: (pc + cc) * (ce - pe) / 2, # 4
71
- lambda pe, pc, ce, cc, el, er: (lerp(el, pe, ce, pc, cc) + lerp(er, pe, ce, pc, cc)) * (er - el) / 2,
78
+ lambda pe, pc, ce, cc, el, er: (lerp(el, pe, ce, pc, cc) + lerp(er, pe, ce, pc, cc))
79
+ * (er - el)
80
+ / 2,
72
81
  # 5
73
82
  ],
74
83
  previous_energy,
@@ -86,23 +95,24 @@ def interp_and_integrate(energy_low, energy_high, energy_ref, continuum_ref, end
86
95
  return integrated_flux
87
96
 
88
97
 
98
+ @jax.jit
89
99
  def interp(e_low, e_high, energy_ref, continuum_ref, end_index):
90
100
  energy_ref = jnp.where(jnp.arange(energy_ref.shape[0]) < end_index, energy_ref, jnp.nan)
91
101
 
92
- return (jnp.interp(e_high, energy_ref, continuum_ref) - jnp.interp(e_low, energy_ref, continuum_ref)) / (e_high - e_low)
102
+ return (
103
+ jnp.interp(e_high, energy_ref, continuum_ref) - jnp.interp(e_low, energy_ref, continuum_ref)
104
+ ) / (e_high - e_low)
93
105
 
94
106
 
95
- def interp_flux(energy, energy_ref, continuum_ref, end_index, integrate=True):
107
+ @jax.jit
108
+ def interp_flux(energy, energy_ref, continuum_ref, end_index):
96
109
  """
97
110
  Iterate through an array of shape (energy_ref,) and compute the flux between the bins defined by energy
98
111
  """
99
112
 
100
113
  def scanned_func(carry, unpack):
101
114
  e_low, e_high = unpack
102
- if integrate:
103
- continuum = interp_and_integrate(e_low, e_high, energy_ref, continuum_ref, end_index)
104
- else:
105
- continuum = interp(e_low, e_high, energy_ref, continuum_ref, end_index)
115
+ continuum = interp_and_integrate(e_low, e_high, energy_ref, continuum_ref, end_index)
106
116
 
107
117
  return carry, continuum
108
118
 
@@ -111,7 +121,8 @@ def interp_flux(energy, energy_ref, continuum_ref, end_index, integrate=True):
111
121
  return continuum
112
122
 
113
123
 
114
- def interp_flux_elements(energy_ref, continuum_ref, end_index, energy, abundances, integrate=True):
124
+ @jax.jit
125
+ def interp_flux_elements(energy_ref, continuum_ref, end_index, energy, abundances):
115
126
  """
116
127
  Iterate through an array of shape (abundance, energy_ref) and compute the flux between the bins defined by energy
117
128
  and weight the flux depending on the abundance of each element
@@ -119,7 +130,7 @@ def interp_flux_elements(energy_ref, continuum_ref, end_index, energy, abundance
119
130
 
120
131
  def scanned_func(_, unpack):
121
132
  energy_ref, continuum_ref, end_idx = unpack
122
- element_flux = interp_flux(energy, energy_ref, continuum_ref, end_idx, integrate=integrate)
133
+ element_flux = interp_flux(energy, energy_ref, continuum_ref, end_idx)
123
134
 
124
135
  return _, element_flux
125
136
 
@@ -136,23 +147,9 @@ def get_lines_contribution_broadening(
136
147
  # Notice the -1 in line element to match the 0-based indexing
137
148
  l_energy, l_emissivity, l_element = line_energy[i], line_emissivity[i], line_element[i] - 1
138
149
  broadening = l_energy * total_broadening[l_element]
139
- l_flux = gaussian.cdf(energy[1:], l_energy, broadening) - gaussian.cdf(energy[:-1], l_energy, broadening)
140
- l_flux = l_flux * l_emissivity * abundances[l_element]
141
-
142
- return flux + l_flux
143
-
144
- return fori_loop(0, end_index, body_func, jnp.zeros_like(energy[:-1]))
145
-
146
-
147
- @jax.jit
148
- def get_lines_contribution_broadening_derivative(
149
- line_energy, line_element, line_emissivity, end_index, energy, abundances, total_broadening
150
- ):
151
- def body_func(i, flux):
152
- # Notice the -1 in line element to match the 0-based indexing
153
- l_energy, l_emissivity, l_element = line_energy[i], line_emissivity[i], line_element[i] - 1
154
- broadening = l_energy * total_broadening[l_element]
155
- l_flux = gaussian.cdf(energy[1:], l_energy, broadening) - gaussian.cdf(energy[:-1], l_energy, broadening)
150
+ l_flux = gaussian.cdf(energy[1:], l_energy, broadening) - gaussian.cdf(
151
+ energy[:-1], l_energy, broadening
152
+ )
156
153
  l_flux = l_flux * l_emissivity * abundances[l_element]
157
154
 
158
155
  return flux + l_flux
@@ -160,7 +157,6 @@ def get_lines_contribution_broadening_derivative(
160
157
  return fori_loop(0, end_index, body_func, jnp.zeros_like(energy[:-1]))
161
158
 
162
159
 
163
- @jax.custom_jvp
164
160
  @jax.jit
165
161
  def continuum_func(energy, kT, abundances):
166
162
  idx, kT_low, kT_high = get_temperature(kT)
@@ -170,96 +166,27 @@ def continuum_func(energy, kT, abundances):
170
166
  return lerp(kT, kT_low, kT_high, continuum_low, continuum_high)
171
167
 
172
168
 
173
- @jax.jit
174
- @continuum_func.defjvp
175
- def continuum_jvp(primals, tangents):
176
- energy, kT, abundances = primals
177
- energy_dot, kT_dot, abundances_dot = tangents
178
-
179
- idx, kT_low, kT_high = get_temperature(kT)
180
- continuum_low_pars = get_continuum(idx)
181
- continuum_high_pars = get_continuum(idx + 1)
182
-
183
- # Energy derivative
184
- dcontinuum_low_denerg = interp_flux_elements(*continuum_low_pars, energy, abundances, integrate=False)
185
- dcontinuum_high_denerg = interp_flux_elements(*continuum_high_pars, energy, abundances, integrate=False)
186
- energy_derivative = lerp(kT, kT_low, kT_high, dcontinuum_low_denerg, dcontinuum_high_denerg) * jnp.diff(energy_dot)
187
-
188
- # Temperature derivative
189
- continuum_low = interp_flux_elements(*continuum_low_pars, energy, abundances)
190
- continuum_high = interp_flux_elements(*continuum_high_pars, energy, abundances)
191
- kT_derivative = (continuum_high - continuum_low) / (kT_high - kT_low) * kT_dot
192
-
193
- # Abundances derivative
194
- dcontinuum_low_dabund = interp_flux_elements(*continuum_low_pars, energy, abundances_dot)
195
- dcontinuum_high_dabund = interp_flux_elements(*continuum_high_pars, energy, abundances_dot)
196
- abundances_derivative = lerp(kT, kT_low, kT_high, dcontinuum_low_dabund, dcontinuum_high_dabund)
197
-
198
- primals_out = continuum_func(*primals)
199
-
200
- return primals_out, energy_derivative + kT_derivative + abundances_derivative
201
-
202
-
203
- @jax.custom_jvp
204
169
  @jax.jit
205
170
  def pseudo_func(energy, kT, abundances):
206
171
  idx, kT_low, kT_high = get_temperature(kT)
207
- continuum_low = interp_flux_elements(*get_continuum(idx), energy, abundances)
208
- continuum_high = interp_flux_elements(*get_continuum(idx + 1), energy, abundances)
172
+ continuum_low = interp_flux_elements(*get_pseudo(idx), energy, abundances)
173
+ continuum_high = interp_flux_elements(*get_pseudo(idx + 1), energy, abundances)
209
174
 
210
175
  return lerp(kT, kT_low, kT_high, continuum_low, continuum_high)
211
176
 
212
177
 
213
- @jax.jit
214
- @pseudo_func.defjvp
215
- def pseudo_jvp(primals, tangents):
216
- energy, kT, abundances = primals
217
- energy_dot, kT_dot, abundances_dot = tangents
218
-
219
- idx, kT_low, kT_high = get_temperature(kT)
220
- continuum_low_pars = get_pseudo(idx)
221
- continuum_high_pars = get_pseudo(idx + 1)
222
-
223
- # Energy derivative
224
- dcontinuum_low_denerg = interp_flux_elements(*continuum_low_pars, energy, abundances, integrate=False)
225
- dcontinuum_high_denerg = interp_flux_elements(*continuum_high_pars, energy, abundances, integrate=False)
226
- energy_derivative = lerp(kT, kT_low, kT_high, dcontinuum_low_denerg, dcontinuum_high_denerg) * jnp.diff(energy_dot)
227
-
228
- # Temperature derivative
229
- continuum_low = interp_flux_elements(*continuum_low_pars, energy, abundances)
230
- continuum_high = interp_flux_elements(*continuum_high_pars, energy, abundances)
231
- kT_derivative = (continuum_high - continuum_low) / (kT_high - kT_low) * kT_dot
232
-
233
- # Abundances derivative
234
- dcontinuum_low_dabund = interp_flux_elements(*continuum_low_pars, energy, abundances_dot)
235
- dcontinuum_high_dabund = interp_flux_elements(*continuum_high_pars, energy, abundances_dot)
236
- abundances_derivative = lerp(kT, kT_low, kT_high, dcontinuum_low_dabund, dcontinuum_high_dabund)
237
-
238
- primals_out = pseudo_func(*primals)
239
-
240
- return primals_out, energy_derivative + kT_derivative + abundances_derivative
241
-
242
-
243
- @jax.custom_jvp
178
+ # @jax.custom_jvp
244
179
  @jax.jit
245
180
  def lines_func(energy, kT, abundances, broadening):
246
181
  idx, kT_low, kT_high = get_temperature(kT)
247
182
  line_low = get_lines_contribution_broadening(*get_lines(idx), energy, abundances, broadening)
248
- line_high = get_lines_contribution_broadening(*get_lines(idx + 1), energy, abundances, broadening)
183
+ line_high = get_lines_contribution_broadening(
184
+ *get_lines(idx + 1), energy, abundances, broadening
185
+ )
249
186
 
250
187
  return lerp(kT, kT_low, kT_high, line_low, line_high)
251
188
 
252
189
 
253
- @jax.jit
254
- @lines_func.defjvp
255
- def lines_jvp(primals, tangents):
256
- energy, kT, abundances, broadening = primals
257
- energy_dot, kT_dot, abundances_dot, broadening_dot = tangents
258
-
259
- primals_out = lines_func(*primals)
260
- return primals_out, jnp.zeros_like(primals_out)
261
-
262
-
263
190
  class APEC(AdditiveComponent):
264
191
  """
265
192
  APEC model implementation in pure JAX for X-ray spectral fitting.
@@ -276,11 +203,15 @@ class APEC(AdditiveComponent):
276
203
  thermal_broadening: bool = True,
277
204
  turbulent_broadening: bool = True,
278
205
  variant: Literal["none", "v", "vv"] = "none",
279
- abundance_table: Literal["angr", "aspl", "feld", "aneb", "grsa", "wilm", "lodd", "lgpp", "lgps"] = "angr",
206
+ abundance_table: Literal[
207
+ "angr", "aspl", "feld", "aneb", "grsa", "wilm", "lodd", "lgpp", "lgps"
208
+ ] = "angr",
280
209
  trace_abundance: float = 1.0,
281
210
  **kwargs,
282
211
  ):
283
- super(APEC, self).__init__(**kwargs)
212
+ super().__init__(**kwargs)
213
+
214
+ warnings.warn("Be aware that this APEC implementation is not meant to be used yet")
284
215
 
285
216
  self.atomic_weights = jnp.asarray(element_data["atomic_weight"].to_numpy())
286
217
 
@@ -325,14 +256,18 @@ class APEC(AdditiveComponent):
325
256
  """
326
257
  if self.turbulent_broadening:
327
258
  # This return value must be multiplied by the energy of the line to get actual broadening
328
- return hk.get_parameter("Velocity", [], init=HaikuConstant(100.0)) / c.to(u.km / u.s).value
259
+ return (
260
+ hk.get_parameter("Velocity", [], init=HaikuConstant(100.0)) / c.to(u.km / u.s).value
261
+ )
329
262
  else:
330
263
  return 0.0
331
264
 
332
265
  def get_parameters(self):
333
266
  none_elements = ["C", "N", "O", "Ne", "Mg", "Al", "Si", "S", "Ar", "Ca", "Fe", "Ni"]
334
267
  v_elements = ["He", "C", "N", "O", "Ne", "Mg", "Al", "Si", "S", "Ar", "Ca", "Fe", "Ni"]
335
- trace_elements = jnp.asarray([3, 4, 5, 9, 11, 15, 17, 19, 21, 22, 23, 24, 25, 27, 29, 30], dtype=int) - 1
268
+ trace_elements = (
269
+ jnp.asarray([3, 4, 5, 9, 11, 15, 17, 19, 21, 22, 23, 24, 25, 27, 29, 30], dtype=int) - 1
270
+ )
336
271
 
337
272
  # Set abundances of trace element (will be overwritten in the vv case)
338
273
  abund = jnp.ones((30,)).at[trace_elements].multiply(self.trace_abundance)
@@ -354,7 +289,9 @@ class APEC(AdditiveComponent):
354
289
  abund = abund.at[i].set(Z)
355
290
 
356
291
  if abund != "angr":
357
- abund = abund * jnp.asarray(abundance_table[self.abundance_table] / abundance_table["angr"])
292
+ abund = abund * jnp.asarray(
293
+ abundance_table[self.abundance_table] / abundance_table["angr"]
294
+ )
358
295
 
359
296
  # Set the temperature, redshift, normalisation
360
297
  kT = hk.get_parameter("kT", [], init=HaikuConstant(6.5))
@@ -372,6 +309,8 @@ class APEC(AdditiveComponent):
372
309
 
373
310
  continuum = continuum_func(energy, kT, abundances) if self.continuum_to_compute else 0.0
374
311
  pseudo_continuum = pseudo_func(energy, kT, abundances) if self.pseudo_to_compute else 0.0
375
- lines = lines_func(energy, kT, abundances, total_broadening) if self.lines_to_compute else 0.0
312
+ lines = (
313
+ lines_func(energy, kT, abundances, total_broadening) if self.lines_to_compute else 0.0
314
+ )
376
315
 
377
316
  return (continuum + pseudo_continuum + lines) * norm * 1e14 / (1 + z), (e_low + e_high) / 2
@@ -1,90 +1,73 @@
1
- """ This module contains the functions that load the APEC tables from the HDF5 file. They are implemented as JAX
2
- pure callback to enable reading data from the files without saturating the memory. """
1
+ """This module contains the functions that load the APEC tables from the HDF5 file. They are implemented as JAX
2
+ pure callback to enable reading data from the files without saturating the memory."""
3
3
 
4
+ import h5netcdf
4
5
  import jax
5
6
  import jax.numpy as jnp
6
- import numpy as np
7
- import importlib.resources
8
- import xarray as xr
9
7
 
8
+ from ...util.online_storage import table_manager
10
9
 
11
- apec_file = xr.open_dataset(importlib.resources.files("jaxspec") / "tables/apec.nc", engine="h5netcdf")
12
10
 
11
+ @jax.jit
12
+ def temperature_table_getter():
13
+ with h5netcdf.File(table_manager.fetch("apec.nc"), "r") as f:
14
+ temperature = jnp.asarray(f["/temperature"])
13
15
 
14
- def temperature_table_getter(kT):
15
- idx = np.searchsorted(apec_file.temperature, kT) - 1
16
-
17
- if idx > len(apec_file.temperature) - 2:
18
- return idx, 0.0, 0.0
19
- else:
20
- return idx, float(apec_file.temperature[idx]), float(apec_file.temperature[idx + 1])
21
-
22
-
23
- def continuum_table_getter(idx):
24
- if idx > len(apec_file.temperature) - 2:
25
- continuum_energy_array = jnp.zeros_like(apec_file.continuum_energy[0])
26
- continuum_emissivity_array = jnp.zeros_like(apec_file.continuum_emissivity[0])
27
- end_index_continuum = jnp.zeros_like(apec_file.continuum_end_index[0])
28
-
29
- else:
30
- continuum_energy_array = jnp.asarray(apec_file.continuum_energy[idx])
31
- continuum_emissivity_array = jnp.asarray(apec_file.continuum_emissivity[idx])
32
- end_index_continuum = jnp.asarray(apec_file.continuum_end_index[idx])
33
-
34
- return continuum_energy_array, continuum_emissivity_array, end_index_continuum
35
-
16
+ return temperature
36
17
 
37
- def pseudo_table_getter(idx):
38
- if idx > len(apec_file.temperature) - 2:
39
- pseudo_energy_array = jnp.zeros_like(apec_file.pseudo_energy[0])
40
- pseudo_emissivity_array = jnp.zeros_like(apec_file.pseudo_emissivity[0])
41
- end_index_pseudo = jnp.zeros_like(apec_file.pseudo_end_index[0])
42
18
 
43
- else:
44
- pseudo_energy_array = jnp.asarray(apec_file.pseudo_energy[idx])
45
- pseudo_emissivity_array = jnp.asarray(apec_file.pseudo_emissivity[idx])
46
- end_index_pseudo = jnp.asarray(apec_file.pseudo_end_index[idx])
19
+ @jax.jit
20
+ def get_temperature(kT):
21
+ temperature = temperature_table_getter()
22
+ idx = jnp.searchsorted(temperature, kT) - 1
47
23
 
48
- return pseudo_energy_array, pseudo_emissivity_array, end_index_pseudo
24
+ return idx, temperature[idx], temperature[idx + 1]
49
25
 
50
26
 
51
- def lines_table_getter(idx):
52
- if idx > len(apec_file.temperature) - 2:
53
- line_energy_array = jnp.zeros_like(apec_file.line_energy[0])
54
- line_element_array = jnp.zeros_like(apec_file.line_element[0])
55
- line_emissivity_array = jnp.zeros_like(apec_file.line_emissivity[0])
56
- end_index_lines = jnp.zeros_like(apec_file.line_end_index[0])
27
+ @jax.jit
28
+ def continuum_table_getter():
29
+ with h5netcdf.File(table_manager.fetch("apec.nc"), "r") as f:
30
+ continuum_energy = jnp.asarray(f["/continuum_energy"])
31
+ continuum_emissivity = jnp.asarray(f["/continuum_emissivity"])
32
+ continuum_end_index = jnp.asarray(f["/continuum_end_index"])
57
33
 
58
- else:
59
- line_energy_array = jnp.asarray(apec_file.line_energy[idx])
60
- line_element_array = jnp.asarray(apec_file.line_element[idx])
61
- line_emissivity_array = jnp.asarray(apec_file.line_emissivity[idx])
62
- end_index_lines = jnp.asarray(apec_file.line_end_index[idx])
34
+ return continuum_energy, continuum_emissivity, continuum_end_index
63
35
 
64
- return line_energy_array, line_element_array, line_emissivity_array, end_index_lines
65
36
 
37
+ @jax.jit
38
+ def pseudo_table_getter():
39
+ with h5netcdf.File(table_manager.fetch("apec.nc"), "r") as f:
40
+ pseudo_energy = jnp.asarray(f["/pseudo_energy"])
41
+ pseudo_emissivity = jnp.asarray(f["/pseudo_emissivity"])
42
+ pseudo_end_index = jnp.asarray(f["/pseudo_end_index"])
66
43
 
67
- pure_callback_temperature_shape = jax.eval_shape(lambda: jax.tree.map(jnp.asarray, temperature_table_getter(10.0)))
68
- pure_callback_continuum_shape = jax.eval_shape(lambda: jax.tree.map(jnp.asarray, continuum_table_getter(0)))
69
- pure_callback_pseudo_shape = jax.eval_shape(lambda: jax.tree.map(jnp.asarray, pseudo_table_getter(0)))
70
- pure_callback_line_shape = jax.eval_shape(lambda: jax.tree.map(jnp.asarray, lines_table_getter(0)))
44
+ return pseudo_energy, pseudo_emissivity, pseudo_end_index
71
45
 
72
46
 
73
47
  @jax.jit
74
- def get_temperature(kT):
75
- return jax.pure_callback(temperature_table_getter, pure_callback_temperature_shape, kT)
48
+ def line_table_getter():
49
+ with h5netcdf.File(table_manager.fetch("apec.nc"), "r") as f:
50
+ line_energy = jnp.asarray(f["/line_energy"])
51
+ line_element = jnp.asarray(f["/line_element"])
52
+ line_emissivity = jnp.asarray(f["/line_emissivity"])
53
+ line_end_index = jnp.asarray(f["/line_end_index"])
54
+
55
+ return line_energy, line_element, line_emissivity, line_end_index
76
56
 
77
57
 
78
58
  @jax.jit
79
59
  def get_continuum(idx):
80
- return jax.pure_callback(continuum_table_getter, pure_callback_continuum_shape, idx)
60
+ continuum_energy, continuum_emissivity, continuum_end_index = continuum_table_getter()
61
+ return continuum_energy[idx], continuum_emissivity[idx], continuum_end_index[idx]
81
62
 
82
63
 
83
64
  @jax.jit
84
65
  def get_pseudo(idx):
85
- return jax.pure_callback(pseudo_table_getter, pure_callback_pseudo_shape, idx)
66
+ pseudo_energy, pseudo_emissivity, pseudo_end_index = pseudo_table_getter()
67
+ return pseudo_energy[idx], pseudo_emissivity[idx], pseudo_end_index[idx]
86
68
 
87
69
 
88
70
  @jax.jit
89
71
  def get_lines(idx):
90
- return jax.pure_callback(lines_table_getter, pure_callback_line_shape, idx)
72
+ line_energy, line_element, line_emissivity, line_end_index = line_table_getter()
73
+ return line_energy[idx], line_element[idx], line_emissivity[idx], line_end_index[idx]