jaxspec 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,16 +1,21 @@
1
- import jax.numpy as jnp
2
- import jax
3
- import haiku as hk
1
+ import warnings
2
+
3
+ from typing import Literal
4
+
4
5
  import astropy.units as u
6
+ import haiku as hk
7
+ import jax
8
+ import jax.numpy as jnp
9
+
10
+ from astropy.constants import c, m_p
11
+ from haiku.initializers import Constant as HaikuConstant
5
12
  from jax import lax
6
- from jax.lax import scan, fori_loop
13
+ from jax.lax import fori_loop, scan
7
14
  from jax.scipy.stats import norm as gaussian
8
- from typing import Literal
15
+
9
16
  from ...util.abundance import abundance_table, element_data
10
- from haiku.initializers import Constant as HaikuConstant
11
- from astropy.constants import c, m_p
12
17
  from ..abc import AdditiveComponent
13
- from .apec_loaders import get_temperature, get_continuum, get_pseudo, get_lines
18
+ from .apec_loaders import get_continuum, get_lines, get_pseudo, get_temperature
14
19
 
15
20
 
16
21
  @jax.jit
@@ -51,7 +56,9 @@ def interp_and_integrate(energy_low, energy_high, energy_ref, continuum_ref, end
51
56
  # Within
52
57
 
53
58
  current_energy_is_between = (energy_low <= current_energy) * (current_energy < energy_high)
54
- previous_energy_is_between = (energy_low <= previous_energy) * (previous_energy < energy_high)
59
+ previous_energy_is_between = (energy_low <= previous_energy) * (
60
+ previous_energy < energy_high
61
+ )
55
62
  energies_within_bins = (previous_energy <= energy_low) * (energy_high < current_energy)
56
63
 
57
64
  case = (
@@ -68,7 +75,9 @@ def interp_and_integrate(energy_low, energy_high, energy_ref, continuum_ref, end
68
75
  lambda pe, pc, ce, cc, el, er: (cc + lerp(el, pe, ce, pc, cc)) * (ce - el) / 2, # 2
69
76
  lambda pe, pc, ce, cc, el, er: (pc + lerp(er, pe, ce, pc, cc)) * (er - pe) / 2, # 3
70
77
  lambda pe, pc, ce, cc, el, er: (pc + cc) * (ce - pe) / 2, # 4
71
- lambda pe, pc, ce, cc, el, er: (lerp(el, pe, ce, pc, cc) + lerp(er, pe, ce, pc, cc)) * (er - el) / 2,
78
+ lambda pe, pc, ce, cc, el, er: (lerp(el, pe, ce, pc, cc) + lerp(er, pe, ce, pc, cc))
79
+ * (er - el)
80
+ / 2,
72
81
  # 5
73
82
  ],
74
83
  previous_energy,
@@ -86,23 +95,24 @@ def interp_and_integrate(energy_low, energy_high, energy_ref, continuum_ref, end
86
95
  return integrated_flux
87
96
 
88
97
 
98
+ @jax.jit
89
99
  def interp(e_low, e_high, energy_ref, continuum_ref, end_index):
90
100
  energy_ref = jnp.where(jnp.arange(energy_ref.shape[0]) < end_index, energy_ref, jnp.nan)
91
101
 
92
- return (jnp.interp(e_high, energy_ref, continuum_ref) - jnp.interp(e_low, energy_ref, continuum_ref)) / (e_high - e_low)
102
+ return (
103
+ jnp.interp(e_high, energy_ref, continuum_ref) - jnp.interp(e_low, energy_ref, continuum_ref)
104
+ ) / (e_high - e_low)
93
105
 
94
106
 
95
- def interp_flux(energy, energy_ref, continuum_ref, end_index, integrate=True):
107
+ @jax.jit
108
+ def interp_flux(energy, energy_ref, continuum_ref, end_index):
96
109
  """
97
110
  Iterate through an array of shape (energy_ref,) and compute the flux between the bins defined by energy
98
111
  """
99
112
 
100
113
  def scanned_func(carry, unpack):
101
114
  e_low, e_high = unpack
102
- if integrate:
103
- continuum = interp_and_integrate(e_low, e_high, energy_ref, continuum_ref, end_index)
104
- else:
105
- continuum = interp(e_low, e_high, energy_ref, continuum_ref, end_index)
115
+ continuum = interp_and_integrate(e_low, e_high, energy_ref, continuum_ref, end_index)
106
116
 
107
117
  return carry, continuum
108
118
 
@@ -111,7 +121,8 @@ def interp_flux(energy, energy_ref, continuum_ref, end_index, integrate=True):
111
121
  return continuum
112
122
 
113
123
 
114
- def interp_flux_elements(energy_ref, continuum_ref, end_index, energy, abundances, integrate=True):
124
+ @jax.jit
125
+ def interp_flux_elements(energy_ref, continuum_ref, end_index, energy, abundances):
115
126
  """
116
127
  Iterate through an array of shape (abundance, energy_ref) and compute the flux between the bins defined by energy
117
128
  and weight the flux depending on the abundance of each element
@@ -119,7 +130,7 @@ def interp_flux_elements(energy_ref, continuum_ref, end_index, energy, abundance
119
130
 
120
131
  def scanned_func(_, unpack):
121
132
  energy_ref, continuum_ref, end_idx = unpack
122
- element_flux = interp_flux(energy, energy_ref, continuum_ref, end_idx, integrate=integrate)
133
+ element_flux = interp_flux(energy, energy_ref, continuum_ref, end_idx)
123
134
 
124
135
  return _, element_flux
125
136
 
@@ -136,23 +147,9 @@ def get_lines_contribution_broadening(
136
147
  # Notice the -1 in line element to match the 0-based indexing
137
148
  l_energy, l_emissivity, l_element = line_energy[i], line_emissivity[i], line_element[i] - 1
138
149
  broadening = l_energy * total_broadening[l_element]
139
- l_flux = gaussian.cdf(energy[1:], l_energy, broadening) - gaussian.cdf(energy[:-1], l_energy, broadening)
140
- l_flux = l_flux * l_emissivity * abundances[l_element]
141
-
142
- return flux + l_flux
143
-
144
- return fori_loop(0, end_index, body_func, jnp.zeros_like(energy[:-1]))
145
-
146
-
147
- @jax.jit
148
- def get_lines_contribution_broadening_derivative(
149
- line_energy, line_element, line_emissivity, end_index, energy, abundances, total_broadening
150
- ):
151
- def body_func(i, flux):
152
- # Notice the -1 in line element to match the 0-based indexing
153
- l_energy, l_emissivity, l_element = line_energy[i], line_emissivity[i], line_element[i] - 1
154
- broadening = l_energy * total_broadening[l_element]
155
- l_flux = gaussian.cdf(energy[1:], l_energy, broadening) - gaussian.cdf(energy[:-1], l_energy, broadening)
150
+ l_flux = gaussian.cdf(energy[1:], l_energy, broadening) - gaussian.cdf(
151
+ energy[:-1], l_energy, broadening
152
+ )
156
153
  l_flux = l_flux * l_emissivity * abundances[l_element]
157
154
 
158
155
  return flux + l_flux
@@ -160,7 +157,6 @@ def get_lines_contribution_broadening_derivative(
160
157
  return fori_loop(0, end_index, body_func, jnp.zeros_like(energy[:-1]))
161
158
 
162
159
 
163
- @jax.custom_jvp
164
160
  @jax.jit
165
161
  def continuum_func(energy, kT, abundances):
166
162
  idx, kT_low, kT_high = get_temperature(kT)
@@ -170,96 +166,27 @@ def continuum_func(energy, kT, abundances):
170
166
  return lerp(kT, kT_low, kT_high, continuum_low, continuum_high)
171
167
 
172
168
 
173
- @jax.jit
174
- @continuum_func.defjvp
175
- def continuum_jvp(primals, tangents):
176
- energy, kT, abundances = primals
177
- energy_dot, kT_dot, abundances_dot = tangents
178
-
179
- idx, kT_low, kT_high = get_temperature(kT)
180
- continuum_low_pars = get_continuum(idx)
181
- continuum_high_pars = get_continuum(idx + 1)
182
-
183
- # Energy derivative
184
- dcontinuum_low_denerg = interp_flux_elements(*continuum_low_pars, energy, abundances, integrate=False)
185
- dcontinuum_high_denerg = interp_flux_elements(*continuum_high_pars, energy, abundances, integrate=False)
186
- energy_derivative = lerp(kT, kT_low, kT_high, dcontinuum_low_denerg, dcontinuum_high_denerg) * jnp.diff(energy_dot)
187
-
188
- # Temperature derivative
189
- continuum_low = interp_flux_elements(*continuum_low_pars, energy, abundances)
190
- continuum_high = interp_flux_elements(*continuum_high_pars, energy, abundances)
191
- kT_derivative = (continuum_high - continuum_low) / (kT_high - kT_low) * kT_dot
192
-
193
- # Abundances derivative
194
- dcontinuum_low_dabund = interp_flux_elements(*continuum_low_pars, energy, abundances_dot)
195
- dcontinuum_high_dabund = interp_flux_elements(*continuum_high_pars, energy, abundances_dot)
196
- abundances_derivative = lerp(kT, kT_low, kT_high, dcontinuum_low_dabund, dcontinuum_high_dabund)
197
-
198
- primals_out = continuum_func(*primals)
199
-
200
- return primals_out, energy_derivative + kT_derivative + abundances_derivative
201
-
202
-
203
- @jax.custom_jvp
204
169
  @jax.jit
205
170
  def pseudo_func(energy, kT, abundances):
206
171
  idx, kT_low, kT_high = get_temperature(kT)
207
- continuum_low = interp_flux_elements(*get_continuum(idx), energy, abundances)
208
- continuum_high = interp_flux_elements(*get_continuum(idx + 1), energy, abundances)
172
+ continuum_low = interp_flux_elements(*get_pseudo(idx), energy, abundances)
173
+ continuum_high = interp_flux_elements(*get_pseudo(idx + 1), energy, abundances)
209
174
 
210
175
  return lerp(kT, kT_low, kT_high, continuum_low, continuum_high)
211
176
 
212
177
 
213
- @jax.jit
214
- @pseudo_func.defjvp
215
- def pseudo_jvp(primals, tangents):
216
- energy, kT, abundances = primals
217
- energy_dot, kT_dot, abundances_dot = tangents
218
-
219
- idx, kT_low, kT_high = get_temperature(kT)
220
- continuum_low_pars = get_pseudo(idx)
221
- continuum_high_pars = get_pseudo(idx + 1)
222
-
223
- # Energy derivative
224
- dcontinuum_low_denerg = interp_flux_elements(*continuum_low_pars, energy, abundances, integrate=False)
225
- dcontinuum_high_denerg = interp_flux_elements(*continuum_high_pars, energy, abundances, integrate=False)
226
- energy_derivative = lerp(kT, kT_low, kT_high, dcontinuum_low_denerg, dcontinuum_high_denerg) * jnp.diff(energy_dot)
227
-
228
- # Temperature derivative
229
- continuum_low = interp_flux_elements(*continuum_low_pars, energy, abundances)
230
- continuum_high = interp_flux_elements(*continuum_high_pars, energy, abundances)
231
- kT_derivative = (continuum_high - continuum_low) / (kT_high - kT_low) * kT_dot
232
-
233
- # Abundances derivative
234
- dcontinuum_low_dabund = interp_flux_elements(*continuum_low_pars, energy, abundances_dot)
235
- dcontinuum_high_dabund = interp_flux_elements(*continuum_high_pars, energy, abundances_dot)
236
- abundances_derivative = lerp(kT, kT_low, kT_high, dcontinuum_low_dabund, dcontinuum_high_dabund)
237
-
238
- primals_out = pseudo_func(*primals)
239
-
240
- return primals_out, energy_derivative + kT_derivative + abundances_derivative
241
-
242
-
243
- @jax.custom_jvp
178
+ # @jax.custom_jvp
244
179
  @jax.jit
245
180
  def lines_func(energy, kT, abundances, broadening):
246
181
  idx, kT_low, kT_high = get_temperature(kT)
247
182
  line_low = get_lines_contribution_broadening(*get_lines(idx), energy, abundances, broadening)
248
- line_high = get_lines_contribution_broadening(*get_lines(idx + 1), energy, abundances, broadening)
183
+ line_high = get_lines_contribution_broadening(
184
+ *get_lines(idx + 1), energy, abundances, broadening
185
+ )
249
186
 
250
187
  return lerp(kT, kT_low, kT_high, line_low, line_high)
251
188
 
252
189
 
253
- @jax.jit
254
- @lines_func.defjvp
255
- def lines_jvp(primals, tangents):
256
- energy, kT, abundances, broadening = primals
257
- energy_dot, kT_dot, abundances_dot, broadening_dot = tangents
258
-
259
- primals_out = lines_func(*primals)
260
- return primals_out, jnp.zeros_like(primals_out)
261
-
262
-
263
190
  class APEC(AdditiveComponent):
264
191
  """
265
192
  APEC model implementation in pure JAX for X-ray spectral fitting.
@@ -276,11 +203,15 @@ class APEC(AdditiveComponent):
276
203
  thermal_broadening: bool = True,
277
204
  turbulent_broadening: bool = True,
278
205
  variant: Literal["none", "v", "vv"] = "none",
279
- abundance_table: Literal["angr", "aspl", "feld", "aneb", "grsa", "wilm", "lodd", "lgpp", "lgps"] = "angr",
206
+ abundance_table: Literal[
207
+ "angr", "aspl", "feld", "aneb", "grsa", "wilm", "lodd", "lgpp", "lgps"
208
+ ] = "angr",
280
209
  trace_abundance: float = 1.0,
281
210
  **kwargs,
282
211
  ):
283
- super(APEC, self).__init__(**kwargs)
212
+ super().__init__(**kwargs)
213
+
214
+ warnings.warn("Be aware that this APEC implementation is not meant to be used yet")
284
215
 
285
216
  self.atomic_weights = jnp.asarray(element_data["atomic_weight"].to_numpy())
286
217
 
@@ -325,14 +256,18 @@ class APEC(AdditiveComponent):
325
256
  """
326
257
  if self.turbulent_broadening:
327
258
  # This return value must be multiplied by the energy of the line to get actual broadening
328
- return hk.get_parameter("Velocity", [], init=HaikuConstant(100.0)) / c.to(u.km / u.s).value
259
+ return (
260
+ hk.get_parameter("Velocity", [], init=HaikuConstant(100.0)) / c.to(u.km / u.s).value
261
+ )
329
262
  else:
330
263
  return 0.0
331
264
 
332
265
  def get_parameters(self):
333
266
  none_elements = ["C", "N", "O", "Ne", "Mg", "Al", "Si", "S", "Ar", "Ca", "Fe", "Ni"]
334
267
  v_elements = ["He", "C", "N", "O", "Ne", "Mg", "Al", "Si", "S", "Ar", "Ca", "Fe", "Ni"]
335
- trace_elements = jnp.asarray([3, 4, 5, 9, 11, 15, 17, 19, 21, 22, 23, 24, 25, 27, 29, 30], dtype=int) - 1
268
+ trace_elements = (
269
+ jnp.asarray([3, 4, 5, 9, 11, 15, 17, 19, 21, 22, 23, 24, 25, 27, 29, 30], dtype=int) - 1
270
+ )
336
271
 
337
272
  # Set abundances of trace element (will be overwritten in the vv case)
338
273
  abund = jnp.ones((30,)).at[trace_elements].multiply(self.trace_abundance)
@@ -354,7 +289,9 @@ class APEC(AdditiveComponent):
354
289
  abund = abund.at[i].set(Z)
355
290
 
356
291
  if abund != "angr":
357
- abund = abund * jnp.asarray(abundance_table[self.abundance_table] / abundance_table["angr"])
292
+ abund = abund * jnp.asarray(
293
+ abundance_table[self.abundance_table] / abundance_table["angr"]
294
+ )
358
295
 
359
296
  # Set the temperature, redshift, normalisation
360
297
  kT = hk.get_parameter("kT", [], init=HaikuConstant(6.5))
@@ -372,6 +309,8 @@ class APEC(AdditiveComponent):
372
309
 
373
310
  continuum = continuum_func(energy, kT, abundances) if self.continuum_to_compute else 0.0
374
311
  pseudo_continuum = pseudo_func(energy, kT, abundances) if self.pseudo_to_compute else 0.0
375
- lines = lines_func(energy, kT, abundances, total_broadening) if self.lines_to_compute else 0.0
312
+ lines = (
313
+ lines_func(energy, kT, abundances, total_broadening) if self.lines_to_compute else 0.0
314
+ )
376
315
 
377
316
  return (continuum + pseudo_continuum + lines) * norm * 1e14 / (1 + z), (e_low + e_high) / 2
@@ -1,90 +1,73 @@
1
- """ This module contains the functions that load the APEC tables from the HDF5 file. They are implemented as JAX
2
- pure callback to enable reading data from the files without saturating the memory. """
1
+ """This module contains the functions that load the APEC tables from the HDF5 file. They are implemented as JAX
2
+ pure callback to enable reading data from the files without saturating the memory."""
3
3
 
4
+ import h5netcdf
4
5
  import jax
5
6
  import jax.numpy as jnp
6
- import numpy as np
7
- import importlib.resources
8
- import xarray as xr
9
7
 
8
+ from ...util.online_storage import table_manager
10
9
 
11
- apec_file = xr.open_dataset(importlib.resources.files("jaxspec") / "tables/apec.nc", engine="h5netcdf")
12
10
 
11
+ @jax.jit
12
+ def temperature_table_getter():
13
+ with h5netcdf.File(table_manager.fetch("apec.nc"), "r") as f:
14
+ temperature = jnp.asarray(f["/temperature"])
13
15
 
14
- def temperature_table_getter(kT):
15
- idx = np.searchsorted(apec_file.temperature, kT) - 1
16
-
17
- if idx > len(apec_file.temperature) - 2:
18
- return idx, 0.0, 0.0
19
- else:
20
- return idx, float(apec_file.temperature[idx]), float(apec_file.temperature[idx + 1])
21
-
22
-
23
- def continuum_table_getter(idx):
24
- if idx > len(apec_file.temperature) - 2:
25
- continuum_energy_array = jnp.zeros_like(apec_file.continuum_energy[0])
26
- continuum_emissivity_array = jnp.zeros_like(apec_file.continuum_emissivity[0])
27
- end_index_continuum = jnp.zeros_like(apec_file.continuum_end_index[0])
28
-
29
- else:
30
- continuum_energy_array = jnp.asarray(apec_file.continuum_energy[idx])
31
- continuum_emissivity_array = jnp.asarray(apec_file.continuum_emissivity[idx])
32
- end_index_continuum = jnp.asarray(apec_file.continuum_end_index[idx])
33
-
34
- return continuum_energy_array, continuum_emissivity_array, end_index_continuum
35
-
16
+ return temperature
36
17
 
37
- def pseudo_table_getter(idx):
38
- if idx > len(apec_file.temperature) - 2:
39
- pseudo_energy_array = jnp.zeros_like(apec_file.pseudo_energy[0])
40
- pseudo_emissivity_array = jnp.zeros_like(apec_file.pseudo_emissivity[0])
41
- end_index_pseudo = jnp.zeros_like(apec_file.pseudo_end_index[0])
42
18
 
43
- else:
44
- pseudo_energy_array = jnp.asarray(apec_file.pseudo_energy[idx])
45
- pseudo_emissivity_array = jnp.asarray(apec_file.pseudo_emissivity[idx])
46
- end_index_pseudo = jnp.asarray(apec_file.pseudo_end_index[idx])
19
+ @jax.jit
20
+ def get_temperature(kT):
21
+ temperature = temperature_table_getter()
22
+ idx = jnp.searchsorted(temperature, kT) - 1
47
23
 
48
- return pseudo_energy_array, pseudo_emissivity_array, end_index_pseudo
24
+ return idx, temperature[idx], temperature[idx + 1]
49
25
 
50
26
 
51
- def lines_table_getter(idx):
52
- if idx > len(apec_file.temperature) - 2:
53
- line_energy_array = jnp.zeros_like(apec_file.line_energy[0])
54
- line_element_array = jnp.zeros_like(apec_file.line_element[0])
55
- line_emissivity_array = jnp.zeros_like(apec_file.line_emissivity[0])
56
- end_index_lines = jnp.zeros_like(apec_file.line_end_index[0])
27
+ @jax.jit
28
+ def continuum_table_getter():
29
+ with h5netcdf.File(table_manager.fetch("apec.nc"), "r") as f:
30
+ continuum_energy = jnp.asarray(f["/continuum_energy"])
31
+ continuum_emissivity = jnp.asarray(f["/continuum_emissivity"])
32
+ continuum_end_index = jnp.asarray(f["/continuum_end_index"])
57
33
 
58
- else:
59
- line_energy_array = jnp.asarray(apec_file.line_energy[idx])
60
- line_element_array = jnp.asarray(apec_file.line_element[idx])
61
- line_emissivity_array = jnp.asarray(apec_file.line_emissivity[idx])
62
- end_index_lines = jnp.asarray(apec_file.line_end_index[idx])
34
+ return continuum_energy, continuum_emissivity, continuum_end_index
63
35
 
64
- return line_energy_array, line_element_array, line_emissivity_array, end_index_lines
65
36
 
37
+ @jax.jit
38
+ def pseudo_table_getter():
39
+ with h5netcdf.File(table_manager.fetch("apec.nc"), "r") as f:
40
+ pseudo_energy = jnp.asarray(f["/pseudo_energy"])
41
+ pseudo_emissivity = jnp.asarray(f["/pseudo_emissivity"])
42
+ pseudo_end_index = jnp.asarray(f["/pseudo_end_index"])
66
43
 
67
- pure_callback_temperature_shape = jax.eval_shape(lambda: jax.tree.map(jnp.asarray, temperature_table_getter(10.0)))
68
- pure_callback_continuum_shape = jax.eval_shape(lambda: jax.tree.map(jnp.asarray, continuum_table_getter(0)))
69
- pure_callback_pseudo_shape = jax.eval_shape(lambda: jax.tree.map(jnp.asarray, pseudo_table_getter(0)))
70
- pure_callback_line_shape = jax.eval_shape(lambda: jax.tree.map(jnp.asarray, lines_table_getter(0)))
44
+ return pseudo_energy, pseudo_emissivity, pseudo_end_index
71
45
 
72
46
 
73
47
  @jax.jit
74
- def get_temperature(kT):
75
- return jax.pure_callback(temperature_table_getter, pure_callback_temperature_shape, kT)
48
+ def line_table_getter():
49
+ with h5netcdf.File(table_manager.fetch("apec.nc"), "r") as f:
50
+ line_energy = jnp.asarray(f["/line_energy"])
51
+ line_element = jnp.asarray(f["/line_element"])
52
+ line_emissivity = jnp.asarray(f["/line_emissivity"])
53
+ line_end_index = jnp.asarray(f["/line_end_index"])
54
+
55
+ return line_energy, line_element, line_emissivity, line_end_index
76
56
 
77
57
 
78
58
  @jax.jit
79
59
  def get_continuum(idx):
80
- return jax.pure_callback(continuum_table_getter, pure_callback_continuum_shape, idx)
60
+ continuum_energy, continuum_emissivity, continuum_end_index = continuum_table_getter()
61
+ return continuum_energy[idx], continuum_emissivity[idx], continuum_end_index[idx]
81
62
 
82
63
 
83
64
  @jax.jit
84
65
  def get_pseudo(idx):
85
- return jax.pure_callback(pseudo_table_getter, pure_callback_pseudo_shape, idx)
66
+ pseudo_energy, pseudo_emissivity, pseudo_end_index = pseudo_table_getter()
67
+ return pseudo_energy[idx], pseudo_emissivity[idx], pseudo_end_index[idx]
86
68
 
87
69
 
88
70
  @jax.jit
89
71
  def get_lines(idx):
90
- return jax.pure_callback(lines_table_getter, pure_callback_line_shape, idx)
72
+ line_energy, line_element, line_emissivity, line_end_index = line_table_getter()
73
+ return line_energy[idx], line_element[idx], line_emissivity[idx], line_end_index[idx]
jaxspec/model/additive.py CHANGED
@@ -1,16 +1,18 @@
1
1
  from __future__ import annotations
2
2
 
3
-
3
+ import astropy.constants
4
+ import astropy.units as u
4
5
  import haiku as hk
5
6
  import jax
6
7
  import jax.numpy as jnp
7
8
  import jax.scipy as jsp
8
- import astropy.units as u
9
- import astropy.constants
9
+
10
10
  from haiku.initializers import Constant as HaikuConstant
11
+
11
12
  from ..util.integrate import integrate_interval
13
+
14
+ # from ._additive.apec import APEC
12
15
  from .abc import AdditiveComponent
13
- from ._additive.apec import APEC # noqa: F401
14
16
 
15
17
 
16
18
  class Powerlaw(AdditiveComponent):
@@ -283,16 +285,16 @@ class Diskpbb(AdditiveComponent):
283
285
  where $$p$$ is a free parameter. The standard disk model, diskbb, is recovered if $$p=0.75$$.
284
286
  If radial advection is important then $$p<0.75$$.
285
287
 
286
- $$\mathcal{M}\left( E \right) = \frac{2\pi(\cos i)r^{2}_{\text{in}}}{pd^2} \int_{T_{\text{in}}}^{T_{\text{out}}}
287
- \left( \frac{T}{T_{\text{in}}} \right)^{-2/p-1} \text{bbody}(E,T) \frac{dT}{T_{\text{in}}}$$
288
+ $$\\mathcal{M}\\left( E \right) = \frac{2\\pi(\\cos i)r^{2}_{\text{in}}}{pd^2} \\int_{T_{\text{in}}}^{T_{\text{out}}}
289
+ \\left( \frac{T}{T_{\text{in}}} \right)^{-2/p-1} \text{bbody}(E,T) \frac{dT}{T_{\text{in}}}$$
288
290
 
289
291
  ??? abstract "Parameters"
290
- * $\text{norm}$ : $\cos i(r_{\text{in}}/d)^{2}$,
291
- where $r_{\text{in}}$ is "an apparent" inner disk radius $\left[\text{km}\right]$,
292
+ * $\text{norm}$ : $\\cos i(r_{\text{in}}/d)^{2}$,
293
+ where $r_{\text{in}}$ is "an apparent" inner disk radius $\\left[\text{km}\right]$,
292
294
  $d$ the distance to the source in units of $10 \text{kpc}$,
293
295
  $i$ the angle of the disk ($i=0$ is face-on)
294
- * $p$ : Exponent of the radial dependence of the disk temperature $\left[\text{dimensionless}\right]$
295
- * $T_{\text{in}}$ : Temperature at inner disk radius $\left[ \mathrm{keV}\right]$
296
+ * $p$ : Exponent of the radial dependence of the disk temperature $\\left[\text{dimensionless}\right]$
297
+ * $T_{\text{in}}$ : Temperature at inner disk radius $\\left[ \\mathrm{keV}\right]$
296
298
  """
297
299
 
298
300
  def continuum(self, energy):
@@ -332,7 +334,13 @@ class Diskbb(AdditiveComponent):
332
334
  return e**2 * (kT / tin) ** (-2 / p - 1) / (jnp.exp(e / kT) - 1)
333
335
 
334
336
  integral = integrate_interval(integrand)
335
- return norm * 2.78e-3 * (0.75 / p) / tin * jnp.vectorize(lambda e: integral(tout, tin, e, tin, p))(energy)
337
+ return (
338
+ norm
339
+ * 2.78e-3
340
+ * (0.75 / p)
341
+ / tin
342
+ * jnp.vectorize(lambda e: integral(tout, tin, e, tin, p))(energy)
343
+ )
336
344
 
337
345
 
338
346
  class Agauss(AdditiveComponent):
@@ -381,7 +389,11 @@ class Zagauss(AdditiveComponent):
381
389
  norm = hk.get_parameter("norm", [], init=HaikuConstant(1))
382
390
  redshift = hk.get_parameter("redshift", [], init=HaikuConstant(0))
383
391
 
384
- return norm * (1 + redshift) * jsp.stats.norm.pdf((hc / energy) / (1 + redshift), loc=line_wavelength, scale=sigma)
392
+ return (
393
+ norm
394
+ * (1 + redshift)
395
+ * jsp.stats.norm.pdf((hc / energy) / (1 + redshift), loc=line_wavelength, scale=sigma)
396
+ )
385
397
 
386
398
 
387
399
  class Zgauss(AdditiveComponent):
@@ -404,4 +416,6 @@ class Zgauss(AdditiveComponent):
404
416
  norm = hk.get_parameter("norm", [], init=HaikuConstant(1))
405
417
  redshift = hk.get_parameter("redshift", [], init=HaikuConstant(0))
406
418
 
407
- return (norm / (1 + redshift)) * jsp.stats.norm.pdf(energy * (1 + redshift), loc=line_energy, scale=sigma)
419
+ return (norm / (1 + redshift)) * jsp.stats.norm.pdf(
420
+ energy * (1 + redshift), loc=line_energy, scale=sigma
421
+ )
@@ -1,9 +1,11 @@
1
1
  from abc import ABC, abstractmethod
2
- from tinygp import kernels, GaussianProcess
3
- from jax.scipy.integrate import trapezoid
4
- import numpyro.distributions as dist
2
+
5
3
  import jax.numpy as jnp
6
4
  import numpyro
5
+ import numpyro.distributions as dist
6
+
7
+ from jax.scipy.integrate import trapezoid
8
+ from tinygp import GaussianProcess, kernels
7
9
 
8
10
 
9
11
  class BackgroundModel(ABC):
@@ -33,7 +35,8 @@ class SubtractedBackground(BackgroundModel):
33
35
 
34
36
  """
35
37
 
36
- def numpyro_model(self, energy, observed_counts, name: str = "bkg", observed=True):
38
+ def numpyro_model(self, obs, spectral_model, name: str = "bkg", observed=True):
39
+ _, observed_counts = obs.out_energies, obs.folded_background.data
37
40
  numpyro.deterministic(f"{name}", observed_counts)
38
41
 
39
42
  return jnp.zeros_like(observed_counts)
@@ -49,32 +52,37 @@ class BackgroundWithError(BackgroundModel):
49
52
  but slower since it performs the fit using MCMC instead of analytical solution.
50
53
  """
51
54
 
52
- def numpyro_model(self, energy, observed_counts, name: str = "bkg", observed=True):
55
+ def numpyro_model(self, obs, spectral_model, name: str = "bkg", observed=True):
53
56
  # Gamma in numpyro is parameterized by concentration and rate (alpha/beta)
57
+ _, observed_counts = obs.out_energies, obs.folded_background.data
54
58
  alpha = observed_counts + 1
55
59
  beta = 1
56
60
  countrate = numpyro.sample(f"{name}_params", dist.Gamma(alpha, rate=beta))
57
61
 
58
62
  with numpyro.plate(f"{name}_plate", len(observed_counts)):
59
- numpyro.sample(f"{name}", dist.Poisson(countrate), obs=observed_counts if observed else None)
63
+ numpyro.sample(
64
+ f"{name}", dist.Poisson(countrate), obs=observed_counts if observed else None
65
+ )
60
66
 
61
67
  return countrate
62
68
 
63
69
 
64
70
  '''
71
+ # TODO: Implement this class and sample it with Gibbs Sampling
72
+
65
73
  class ConjugateBackground(BackgroundModel):
66
74
  r"""
67
- This class fit an expected rate $\lambda$ in each bin of the background spectrum. Assuming a Gamma prior
75
+ This class fit an expected rate $\\lambda$ in each bin of the background spectrum. Assuming a Gamma prior
68
76
  distribution, we can analytically derive the posterior as a Negative binomial distribution.
69
77
 
70
- $$ p(\lambda_{\text{Bkg}}) \sim \Gamma \left( \alpha, \beta \right) \implies
71
- p\left(\lambda_{\text{Bkg}} | \text{Counts}_{\text{Bkg}}\right) \sim \text{NB}\left(\alpha, \frac{\beta}{\beta +1}
78
+ $$ p(\\lambda_{\text{Bkg}}) \\sim \\Gamma \\left( \alpha, \beta \right) \\implies
79
+ p\\left(\\lambda_{\text{Bkg}} | \text{Counts}_{\text{Bkg}}\right) \\sim \text{NB}\\left(\alpha, \frac{\beta}{\beta +1}
72
80
  \right) $$
73
81
 
74
82
  !!! info
75
83
  Here, $\alpha$ and $\beta$ are set to $\alpha = \text{Counts}_{\text{Bkg}} + 1$ and $\beta = 1$. Doing so,
76
- the prior distribution is such that $\mathbb{E}[\lambda_{\text{Bkg}}] = \text{Counts}_{\text{Bkg}} +1$ and
77
- $\text{Var}[\lambda_{\text{Bkg}}] = \text{Counts}_{\text{Bkg}}+1$. The +1 is to avoid numerical issues when the
84
+ the prior distribution is such that $\\mathbb{E}[\\lambda_{\text{Bkg}}] = \text{Counts}_{\text{Bkg}} +1$ and
85
+ $\text{Var}[\\lambda_{\text{Bkg}}] = \text{Counts}_{\text{Bkg}}+1$. The +1 is to avoid numerical issues when the
78
86
  counts are 0, and add a small scatter even if the measured background is effectively null.
79
87
 
80
88
  ??? abstract "References"
@@ -97,6 +105,19 @@ class ConjugateBackground(BackgroundModel):
97
105
  return countrate
98
106
  '''
99
107
 
108
+ """
109
+ class SpectralBackgroundModel(BackgroundModel):
110
+ # I should pass the current spectral model as an argument to the background model
111
+ # In the numpyro model function
112
+ def __init__(self, model, prior):
113
+ self.model = model
114
+ self.prior = prior
115
+
116
+ def numpyro_model(self, energy, observed_counts, name: str = "bkg", observed=True):
117
+ #TODO : keep the sparsification from top model
118
+ transformed_model = hk.without_apply_rng(hk.transform(lambda par: CountForwardModel(model, obs, sparse=False)(par)))
119
+ """
120
+
100
121
 
101
122
  class GaussianProcessBackground(BackgroundModel):
102
123
  """
@@ -104,7 +125,13 @@ class GaussianProcessBackground(BackgroundModel):
104
125
  [`tinygp`](https://tinygp.readthedocs.io/en/stable/guide.html) library.
105
126
  """
106
127
 
107
- def __init__(self, e_min: float, e_max: float, n_nodes: int = 30, kernel: kernels.Kernel = kernels.Matern52):
128
+ def __init__(
129
+ self,
130
+ e_min: float,
131
+ e_max: float,
132
+ n_nodes: int = 30,
133
+ kernel: kernels.Kernel = kernels.Matern52,
134
+ ):
108
135
  """
109
136
  Build the Gaussian Process background model.
110
137
 
@@ -119,7 +146,7 @@ class GaussianProcessBackground(BackgroundModel):
119
146
  self.n_nodes = n_nodes
120
147
  self.kernel = kernel
121
148
 
122
- def numpyro_model(self, energy, observed_counts, name: str = "bkg", observed=True):
149
+ def numpyro_model(self, obs, spectral_model, name: str = "bkg", observed=True):
123
150
  """
124
151
  Build the model for the background.
125
152
 
@@ -129,9 +156,12 @@ class GaussianProcessBackground(BackgroundModel):
129
156
  name: The name of the background model for parameters disambiguation.
130
157
  observed: Whether the model is observed or not. Useful for `numpyro.infer.Predictive` calls.
131
158
  """
159
+ energy, observed_counts = obs.out_energies, obs.folded_background.data
132
160
 
133
161
  if (observed_counts is not None) and (self.n_nodes >= len(observed_counts)):
134
- raise RuntimeError("More nodes than channels in the observation associated with GaussianProcessBackground.")
162
+ raise RuntimeError(
163
+ "More nodes than channels in the observation associated with GaussianProcessBackground."
164
+ )
135
165
 
136
166
  # The parameters of the GP model
137
167
  mean = numpyro.sample(f"{name}_mean", dist.Normal(jnp.log(jnp.mean(observed_counts)), 2.0))
@@ -144,13 +174,17 @@ class GaussianProcessBackground(BackgroundModel):
144
174
  gp = GaussianProcess(kernel, nodes, diag=1e-5 * jnp.ones_like(nodes), mean=mean)
145
175
 
146
176
  log_rate = numpyro.sample(f"_{name}_log_rate_nodes", gp.numpyro_dist())
147
- interp_count_rate = jnp.exp(jnp.interp(energy, nodes * (self.e_max - self.e_min) + self.e_min, log_rate))
177
+ interp_count_rate = jnp.exp(
178
+ jnp.interp(energy, nodes * (self.e_max - self.e_min) + self.e_min, log_rate)
179
+ )
148
180
  count_rate = trapezoid(interp_count_rate, energy, axis=0)
149
181
 
150
182
  # Finally, our observation model is Poisson
151
183
  with numpyro.plate(f"{name}_plate", len(observed_counts)):
152
184
  # TODO : change to Poisson Likelihood when there is no background model
153
185
  # TODO : Otherwise clip the background model to 1e-6 to avoid numerical issues
154
- numpyro.sample(f"{name}", dist.Poisson(count_rate), obs=observed_counts if observed else None)
186
+ numpyro.sample(
187
+ f"{name}", dist.Poisson(count_rate), obs=observed_counts if observed else None
188
+ )
155
189
 
156
190
  return count_rate