jaxspec 0.1.4__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxspec/model/additive.py CHANGED
@@ -1,8 +1,10 @@
1
1
  from __future__ import annotations
2
2
 
3
+ from functools import partial
4
+
3
5
  import astropy.constants
4
6
  import astropy.units as u
5
- import haiku as hk
7
+ import flax.nnx as nnx
6
8
  import interpax
7
9
  import jax
8
10
  import jax.numpy as jnp
@@ -10,7 +12,6 @@ import jax.scipy as jsp
10
12
  import numpy as np
11
13
 
12
14
  from astropy.table import Table
13
- from haiku.initializers import Constant as HaikuConstant
14
15
 
15
16
  from ..util.integrate import integrate_interval
16
17
  from ..util.online_storage import table_manager
@@ -23,17 +24,18 @@ class Powerlaw(AdditiveComponent):
23
24
 
24
25
  $$\mathcal{M}\left( E \right) = K \left( \frac{E}{E_0} \right)^{-\alpha}$$
25
26
 
26
- ??? abstract "Parameters"
27
- * $\alpha$ : Photon index of the power law $\left[\text{dimensionless}\right]$
28
- * $E_0$ : Reference energy fixed at 1 keV $\left[ \mathrm{keV}\right]$
29
- * $K$ : Normalization at the reference energy (1 keV) $\left[\frac{\text{photons}}{\text{cm}^2\text{s}}\right]$
27
+ !!! abstract "Parameters"
28
+ * $\alpha$ (`alpha`) $\left[\text{dimensionless}\right]$ : Photon index of the power law
29
+ * $E_0$ $\left[ \mathrm{keV}\right]$ : Reference energy fixed at 1 keV
30
+ * $K$ (`norm`) $\left[\frac{\text{photons}}{\text{keV}\text{cm}^2\text{s}}\right]$ : Normalization at the reference energy (1 keV)
30
31
  """
31
32
 
32
- def continuum(self, energy):
33
- alpha = hk.get_parameter("alpha", [], float, init=HaikuConstant(1.3))
34
- norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1e-4))
33
+ def __init__(self):
34
+ self.alpha = nnx.Param(1.7)
35
+ self.norm = nnx.Param(1e-4)
35
36
 
36
- return norm * energy ** (-alpha)
37
+ def continuum(self, energy):
38
+ return self.norm * energy ** (-self.alpha)
37
39
 
38
40
 
39
41
  class Additiveconstant(AdditiveComponent):
@@ -42,19 +44,15 @@ class Additiveconstant(AdditiveComponent):
42
44
 
43
45
  $$\mathcal{M}\left( E \right) = K$$
44
46
 
45
- ??? abstract "Parameters"
46
- * $K$ : Normalization $\left[\frac{\text{photons}}{\text{cm}^2\text{s}}\right]$
47
+ !!! abstract "Parameters"
48
+ * $K$ (`norm`) $\left[\frac{\text{photons}}{\text{keV}\text{cm}^2\text{s}}\right]$ : Normalization
47
49
  """
48
50
 
49
- def continuum(self, energy):
50
- norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
51
-
52
- return norm * jnp.ones_like(energy)
53
-
54
- def primitive(self, energy):
55
- norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
51
+ def __init__(self):
52
+ self.norm = nnx.Param(1.0)
56
53
 
57
- return norm * energy
54
+ def integrated_continuum(self, e_low, e_high):
55
+ return (e_high - e_low) * self.norm
58
56
 
59
57
 
60
58
  class Lorentz(AdditiveComponent):
@@ -63,18 +61,24 @@ class Lorentz(AdditiveComponent):
63
61
 
64
62
  $$\mathcal{M}\left( E \right) = K\frac{\frac{\sigma}{2\pi}}{(E-E_L)^2 + \left(\frac{\sigma}{2}\right)^2}$$
65
63
 
66
- ??? abstract "Parameters"
67
- - $E_L$ : Energy of the line $\left[\text{keV}\right]$
68
- - $\sigma$ : FWHM of the line $\left[\text{keV}\right]$
69
- - $K$ : Normalization $\left[\frac{\text{photons}}{\text{cm}^2\text{s}}\right]$
64
+ !!! abstract "Parameters"
65
+ - $E_L$ (`E_l`) $\left[\text{keV}\right]$ : Energy of the line
66
+ - $\sigma$ (`sigma`) $\left[\text{keV}\right]$ : FWHM of the line
67
+ - $K$ (`norm`) $\left[\frac{\text{photons}}{\text{cm}^2\text{s}}\right]$ : Normalization
70
68
  """
71
69
 
72
- def continuum(self, energy):
73
- line_energy = hk.get_parameter("E_l", [], float, init=HaikuConstant(1))
74
- sigma = hk.get_parameter("sigma", [], float, init=HaikuConstant(1))
75
- norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
70
+ def __init__(self):
71
+ self.E_l = jnp.asarray(nnx.Param(1.0), dtype=jnp.float64)
72
+ self.sigma = jnp.asarray(nnx.Param(1e-3), dtype=jnp.float64)
73
+ self.norm = jnp.asarray(nnx.Param(1.0), dtype=jnp.float64)
76
74
 
77
- return norm * sigma / (2 * jnp.pi) / ((energy - line_energy) ** 2 + (sigma / 2) ** 2)
75
+ def continuum(self, energy):
76
+ return (
77
+ self.norm
78
+ * self.sigma
79
+ / (2 * jnp.pi)
80
+ / ((energy - self.E_l) ** 2 + (self.sigma / 2) ** 2)
81
+ )
78
82
 
79
83
 
80
84
  class Logparabola(AdditiveComponent):
@@ -83,23 +87,23 @@ class Logparabola(AdditiveComponent):
83
87
 
84
88
  $$
85
89
  \mathcal{M}\left( E \right) = K \left( \frac{E}{E_{\text{Pivot}}} \right)
86
- ^{-(\alpha + \beta ~ \log (E/E_{\text{Pivot}})) }
90
+ ^{-(\alpha - \beta ~ \log (E/E_{\text{Pivot}})) }
87
91
  $$
88
92
 
89
- ??? abstract "Parameters"
90
- * $a$ : Slope of the LogParabola at the pivot energy $\left[\text{dimensionless}\right]$
91
- * $b$ : Curve parameter of the LogParabola $\left[\text{dimensionless}\right]$
92
- * $E_{\text{Pivot}}$ : Pivot energy fixed at 1 keV $\left[ \mathrm{keV}\right]$
93
- * $K$ : Normalization at the pivot energy (1keV) $\left[\frac{\text{photons}}{\text{cm}^2\text{s}}\right]$
93
+ !!! abstract "Parameters"
94
+ * $\alpha$ (`a`) $\left[\text{dimensionless}\right]$ : Slope of the LogParabola at the pivot energy
95
+ * $\beta$ (`b`) $\left[\text{dimensionless}\right]$ : Curvature parameter of the LogParabola
96
+ * $E_{\text{Pivot}}$ $\left[ \mathrm{keV}\right]$ : Pivot energy fixed at 1 keV
97
+ * $K$ (`norm`) $\left[\frac{\text{photons}}{\text{keV}\text{cm}^2\text{s}}\right]$ : Normalization
94
98
  """
95
99
 
96
- # TODO : conform with xspec definition
97
- def continuum(self, energy):
98
- a = hk.get_parameter("a", [], float, init=HaikuConstant(11 / 3))
99
- b = hk.get_parameter("b", [], float, init=HaikuConstant(0.2))
100
- norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
100
+ def __init__(self):
101
+ self.a = nnx.Param(1.0)
102
+ self.b = nnx.Param(1.0)
103
+ self.norm = nnx.Param(1.0)
101
104
 
102
- return norm * energy ** (-(a + b * jnp.log(energy)))
105
+ def continuum(self, energy):
106
+ return self.norm * energy ** (-(self.a - self.b * jnp.log(energy)))
103
107
 
104
108
 
105
109
  class Blackbody(AdditiveComponent):
@@ -108,17 +112,20 @@ class Blackbody(AdditiveComponent):
108
112
 
109
113
  $$\mathcal{M}\left( E \right) = \frac{K \times 8.0525 E^{2}}{(k_B T)^{4}\left(\exp(E/k_BT)-1\right)}$$
110
114
 
111
- ??? abstract "Parameters"
112
- * $k_B T$ : Temperature $\left[\text{keV}\right]$
113
- * $K$ : $L_{39}/D_{10}^{2}$, where $L_{39}$ is the source luminosity in units of $10^{39}$ erg/s
114
- and $D_{10}$ is the distance to the source in units of 10 kpc
115
+ !!! abstract "Parameters"
116
+ * $k_B T$ (`kT`) $\left[\text{keV}\right]$ : Temperature
117
+ * $K$ (`norm`) $\left[\text{dimensionless}\right]$ : $L_{39}/D_{10}^{2}$, where $L_{39}$ is the source luminosity [$10^{39} \frac{\text{erg}}{\text{s}}$]
118
+ and $D_{10}$ is the distance to the source [$10 \text{kpc}$]
115
119
  """
116
120
 
117
- def continuum(self, energy):
118
- kT = hk.get_parameter("kT", [], float, init=HaikuConstant(11 / 3))
119
- norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
121
+ # TODO : express the numerical constant 8.0525 using astropy units
120
122
 
121
- return norm * 8.0525 * energy**2 / ((kT**4) * (jnp.exp(energy / kT) - 1))
123
+ def __init__(self):
124
+ self.kT = nnx.Param(0.5)
125
+ self.norm = nnx.Param(1.0)
126
+
127
+ def continuum(self, energy):
128
+ return self.norm * 8.0525 * energy**2 / ((self.kT**4) * jnp.expm1(energy / self.kT))
122
129
 
123
130
 
124
131
  class Blackbodyrad(AdditiveComponent):
@@ -127,137 +134,52 @@ class Blackbodyrad(AdditiveComponent):
127
134
 
128
135
  $$\mathcal{M}\left( E \right) = \frac{K \times 1.0344\times 10^{-3} E^{2}}{\left(\exp (E/k_BT)-1\right)}$$
129
136
 
130
- ??? abstract "Parameters"
131
- * $k_B T$ : Temperature $\left[\text{keV}\right]$
132
- * $K$ : $R^2_{km}/D_{10}^{2}$, where $R_{km}$ is the source radius in km
133
- and $D_{10}$ is the distance to the source in units of 10 kpc [dimensionless]
137
+ !!! abstract "Parameters"
138
+ * $k_B T$ (`kT`) $\left[\text{keV}\right]$ : Temperature
139
+ * $K$ (`norm`) [dimensionless] : $R^2_{km}/D_{10}^{2}$, where $R_{km}$ is the source radius [$\text{km}$]
140
+ and $D_{10}$ is the distance to the source [$10 \text{kpc}$]
134
141
  """
135
142
 
136
- def continuum(self, energy):
137
- kT = hk.get_parameter("kT", [], float, init=HaikuConstant(11 / 3))
138
- norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
143
+ def __init__(self):
144
+ self.kT = nnx.Param(0.5)
145
+ self.norm = nnx.Param(1.0)
139
146
 
140
- return norm * 1.0344e-3 * energy**2 / (jnp.exp(energy / kT) - 1)
147
+ def continuum(self, energy):
148
+ return self.norm * 1.0344e-3 * energy**2 / jnp.expm1(energy / self.kT)
141
149
 
142
150
 
143
151
  class Gauss(AdditiveComponent):
144
152
  r"""
145
- A Gaussian line profile. If the width is $$\leq 0$$ then it is treated as a delta function.
153
+ A Gaussian line profile. If the width is $\leq 0$ then it is treated as a delta function.
146
154
  The `Zgauss` variant computes a redshifted Gaussian.
147
155
 
148
156
  $$\mathcal{M}\left( E \right) = \frac{K}{\sigma \sqrt{2 \pi}}\exp\left(\frac{-(E-E_L)^2}{2\sigma^2}\right)$$
149
157
 
150
- ??? abstract "Parameters"
151
- * $E_L$ : Energy of the line $\left[\text{keV}\right]$
152
- * $\sigma$ : Width of the line $\left[\text{keV}\right]$
153
- * $K$ : Normalization $\left[\frac{\text{photons}}{\text{cm}^2\text{s}}\right]$
158
+ !!! abstract "Parameters"
159
+ * $E_L$ (`E_l`) $\left[\text{keV}\right]$ : Energy of the line
160
+ * $\sigma$ (`sigma`) $\left[\text{keV}\right]$ : Width of the line
161
+ * $K$ (`norm`) $\left[\frac{\text{photons}}{\text{cm}^2\text{s}}\right]$ : Normalization
154
162
  """
155
163
 
156
- def continuum(self, energy) -> (jax.Array, jax.Array):
157
- line_energy = hk.get_parameter("E_l", [], float, init=HaikuConstant(1))
158
- sigma = hk.get_parameter("sigma", [], float, init=HaikuConstant(1))
159
- norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
160
-
161
- return norm * jsp.stats.norm.pdf(energy, loc=line_energy, scale=sigma)
162
-
163
-
164
- """
165
- class APEC(AdditiveComponent):
166
- def __init__(self, name="apec"):
167
- super(APEC, self).__init__(name=name)
168
-
169
- ref = importlib.resources.files("jaxspec") / "tables/apec_tab.npz"
170
- with importlib.resources.as_file(ref) as path:
171
- files = np.load(path)
172
-
173
- self.kT_ref = files["kT_ref"]
174
- self.e_ref = np.nan_to_num(files["continuum_energy"], nan=1e6)
175
- self.c_ref = np.nan_to_num(files["continuum_emissivity"])
176
- self.pe_ref = np.nan_to_num(files["pseudo_energy"], nan=1e6)
177
- self.pc_ref = np.nan_to_num(files["pseudo_emissivity"])
178
- self.energy_lines = np.nan_to_num(files["lines_energy"], nan=1e6) # .astype(np.float32))
179
- self.epsilon_lines = np.nan_to_num(files["lines_emissivity"]) # .astype(np.float32))
180
- self.element_lines = np.nan_to_num(files["lines_element"]) # .astype(np.int32))
181
-
182
- del files
183
-
184
- self.trace_elements = jnp.array([3, 4, 5, 9, 11, 15, 17, 19, 21, 22, 23, 24, 25, 27, 29, 30]) - 1
185
- self.metals = np.array([6, 7, 8, 10, 12, 13, 14, 16, 18, 20, 26, 28]) - 1 # Element number to python index
186
- self.metals_one_hot = np.zeros((30,))
187
- self.metals_one_hot[self.metals] = 1
188
-
189
- def interp_on_cubes(self, energy, energy_cube, continuum_cube):
190
- # Changer en loginterp
191
- # Interpoler juste sur les points qui ne sont pas tabulés
192
- # Ajouter l'info de la taille max des données (à resortir dans la routine qui trie les fichier apec)
193
- return jnp.vectorize(
194
- lambda ecube, ccube: jnp.interp(energy, ecube, ccube),
195
- signature="(k),(k)->()",
196
- )(energy_cube, continuum_cube)
197
-
198
- def reduction_with_elements(self, Z, energy, energy_cube, continuum_cube):
199
- return jnp.sum(
200
- self.interp_on_cubes(energy, energy_cube, continuum_cube) * jnp.where(self.metals_one_hot, Z, 1.0)[None, :],
201
- axis=-1,
202
- )
203
-
204
- def mono_fine_structure(self, e_low, e_high) -> (jax.Array, jax.Array):
205
- norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
206
- kT = hk.get_parameter("kT", [], float, init=HaikuConstant(1))
207
- Z = hk.get_parameter("Z", [], float, init=HaikuConstant(1))
208
-
209
- idx = jnp.searchsorted(self.kT_ref, kT, side="left") - 1
210
-
211
- energy = jax_slice(self.energy_lines, idx, 2)
212
- epsilon = jax_slice(self.epsilon_lines, idx, 2)
213
- element = jax_slice(self.element_lines, idx, 2) - 1
214
-
215
- emissivity_in_bins = jnp.where((e_low < energy) & (energy < e_high), True, False) * epsilon
216
- flux_at_edges = jnp.nansum(
217
- jnp.where(jnp.isin(element, self.metals), Z, 1) * emissivity_in_bins,
218
- axis=-1,
219
- ) # Coeff for metallicity
220
-
221
- return (
222
- jnp.interp(kT, jax_slice(self.kT_ref, idx, 2), flux_at_edges) * 1e14 * norm,
223
- (e_low + e_high) / 2,
224
- )
225
-
226
- def emission_lines(self, e_low, e_high) -> (jax.Array, jax.Array):
227
- # Compute the fine structure lines with e_low and e_high as array, mapping the mono_fine_structure function
228
- # over the various axes of e_low and e_high
229
-
230
- return jnp.vectorize(self.mono_fine_structure)(e_low, e_high)
231
-
232
- # return jnp.zeros_like(e_low), (e_low + e_high)/2
233
-
234
- @partial(jnp.vectorize, excluded=(0,))
235
- def continuum(self, energy):
236
- norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
237
- kT = hk.get_parameter("kT", [], float, init=HaikuConstant(1))
238
- Z = hk.get_parameter("Z", [], float, init=HaikuConstant(1))
239
-
240
- idx = jnp.searchsorted(self.kT_ref, kT, side="left") - 1 # index of left value
241
-
242
- continuum = jnp.interp(
243
- kT,
244
- jax_slice(self.kT_ref, idx, 2),
245
- self.reduction_with_elements(Z, energy, jax_slice(self.e_ref, idx, 2), jax_slice(self.c_ref, idx, 2)),
246
- )
247
- pseudo = jnp.interp(
248
- kT,
249
- jax_slice(self.kT_ref, idx, 2),
250
- self.reduction_with_elements(
251
- Z,
252
- energy,
253
- jax_slice(self.pe_ref, idx, 2),
254
- jax_slice(self.pc_ref, idx, 2),
255
- ),
164
+ def __init__(self):
165
+ self.E_l = nnx.Param(2.0)
166
+ self.sigma = nnx.Param(1e-2)
167
+ self.norm = nnx.Param(1.0)
168
+
169
+ def integrated_continuum(self, e_low, e_high):
170
+ return self.norm * (
171
+ jsp.stats.norm.cdf(
172
+ e_high,
173
+ loc=jnp.asarray(self.E_l, dtype=jnp.float64),
174
+ scale=jnp.asarray(self.sigma, dtype=jnp.float64),
175
+ )
176
+ - jsp.stats.norm.cdf(
177
+ e_low,
178
+ loc=jnp.asarray(self.E_l, dtype=jnp.float64),
179
+ scale=jnp.asarray(self.sigma, dtype=jnp.float64),
180
+ )
256
181
  )
257
182
 
258
- return (continuum + pseudo) * 1e14 * norm
259
- """
260
-
261
183
 
262
184
  class Cutoffpl(AdditiveComponent):
263
185
  r"""
@@ -265,71 +187,40 @@ class Cutoffpl(AdditiveComponent):
265
187
 
266
188
  $$\mathcal{M}\left( E \right) = K \left( \frac{E}{E_0} \right)^{-\alpha} \exp(-E/\beta)$$
267
189
 
268
- ??? abstract "Parameters"
269
- * $\alpha$ : Photon index of the power law $\left[\text{dimensionless}\right]$
270
- * $\beta$ : Folding energy of the exponential cutoff $\left[\text{keV}\right]$
271
- * $E_0$ : Reference energy fixed at 1 keV $\left[ \mathrm{keV}\right]$
272
- * $K$ : Normalization at the reference energy (1 keV) $\left[\frac{\text{photons}}{\text{cm}^2\text{s}}\right]$
190
+ !!! abstract "Parameters"
191
+ * $\alpha$ (`alpha`) $\left[\text{dimensionless}\right]$ : Photon index of the power law
192
+ * $\beta$ (`beta`) $\left[\text{keV}\right]$ : Folding energy of the exponential cutoff
193
+ * $E_0$ $\left[ \mathrm{keV}\right]$ : Reference energy fixed at 1 keV
194
+ * $K$ (`norm`) $\left[\frac{\text{photons}}{\text{keV}\text{cm}^2\text{s}}\right]$ : Normalization
273
195
  """
274
196
 
275
- def continuum(self, energy):
276
- alpha = hk.get_parameter("alpha", [], float, init=HaikuConstant(1.3))
277
- beta = hk.get_parameter("beta", [], float, init=HaikuConstant(15))
278
- norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1e-4))
279
-
280
- return norm * energy ** (-alpha) * jnp.exp(-energy / beta)
281
-
282
-
283
- '''
284
- class Diskpbb(AdditiveComponent):
285
- r"""
286
- A multiple blackbody disk model where local disk temperature T(r) is proportional to $$r^{(-p)}$$,
287
- where $$p$$ is a free parameter. The standard disk model, diskbb, is recovered if $$p=0.75$$.
288
- If radial advection is important then $$p<0.75$$.
289
-
290
- $$\\mathcal{M}\\left( E \right) = \frac{2\\pi(\\cos i)r^{2}_{\text{in}}}{pd^2} \\int_{T_{\text{in}}}^{T_{\text{out}}}
291
- \\left( \frac{T}{T_{\text{in}}} \right)^{-2/p-1} \text{bbody}(E,T) \frac{dT}{T_{\text{in}}}$$
292
-
293
- ??? abstract "Parameters"
294
- * $\text{norm}$ : $\\cos i(r_{\text{in}}/d)^{2}$,
295
- where $r_{\text{in}}$ is "an apparent" inner disk radius $\\left[\text{km}\right]$,
296
- $d$ the distance to the source in units of $10 \text{kpc}$,
297
- $i$ the angle of the disk ($i=0$ is face-on)
298
- * $p$ : Exponent of the radial dependence of the disk temperature $\\left[\text{dimensionless}\right]$
299
- * $T_{\text{in}}$ : Temperature at inner disk radius $\\left[ \\mathrm{keV}\right]$
300
- """
197
+ def __init__(self):
198
+ self.alpha = nnx.Param(1.7)
199
+ self.beta = nnx.Param(15.0)
200
+ self.norm = nnx.Param(1e-4)
301
201
 
302
202
  def continuum(self, energy):
303
- norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
304
- p = hk.get_parameter("p", [], float, init=HaikuConstant(0.75))
305
- tin = hk.get_parameter("Tin", [], float, init=HaikuConstant(1))
306
-
307
- # Tout is set to 0 as it is evaluated at R=infinity
308
- def integrand(kT, energy):
309
- return 2.78e-3 * energy**2 * (kT / tin) ** (-2 / p - 1) / (jnp.exp(energy / kT) - 1)
310
-
311
- func_vmapped = jax.vmap(lambda e: integrate_interval(lambda kT: integrand(kT, e), 0, tin, n=51))
312
-
313
- return norm * (0.75 / p) * func_vmapped(energy)
314
- '''
203
+ return self.norm * energy ** (-self.alpha) * jnp.exp(-energy / self.beta)
315
204
 
316
205
 
317
206
  class Diskbb(AdditiveComponent):
318
207
  r"""
319
208
  `Diskpbb` with $p=0.75$
320
209
 
321
- ??? abstract "Parameters"
322
- * $T_{\text{in}}$ : Temperature at inner disk radius $\left[ \mathrm{keV}\right]$
323
- * $\text{norm}$ : $\cos i(r_{\text{in}}/d)^{2}$,
324
- where $r_{\text{in}}$ is "an apparent" inner disk radius $\left[\text{km}\right]$,
325
- $d$ the distance to the source in units of $10 \text{kpc}$, $i$ the angle of the disk ($i=0$ is face-on)
210
+ !!! abstract "Parameters"
211
+ * $T_{\text{in}}$ (`Tin`) $\left[ \mathrm{keV}\right]$: Temperature at inner disk radius
212
+ * $\text{norm}$ (`norm`) $\left[\text{dimensionless}\right]$ : $\cos i(r_{\text{in}}/d)^{2}$,
213
+ where $r_{\text{in}}$ is an apparent inner disk radius $\left[\text{km}\right]$,
214
+ $d$ the distance to the source [$10 \text{kpc}$], $i$ the angle of the disk ($i=0$ is face-on)
326
215
  """
327
216
 
217
+ def __init__(self):
218
+ self.Tin = nnx.Param(1.0)
219
+ self.norm = nnx.Param(1e-4)
220
+
328
221
  def continuum(self, energy):
329
222
  p = 0.75
330
223
  tout = 0.0
331
- tin = hk.get_parameter("Tin", [], float, init=HaikuConstant(1))
332
- norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
333
224
 
334
225
  # Tout is set to 0 as it is evaluated at R=infinity
335
226
  def integrand(kT, e, tin, p):
@@ -337,11 +228,11 @@ class Diskbb(AdditiveComponent):
337
228
 
338
229
  integral = integrate_interval(integrand)
339
230
  return (
340
- norm
231
+ self.norm
341
232
  * 2.78e-3
342
233
  * (0.75 / p)
343
- / tin
344
- * jnp.vectorize(lambda e: integral(tout, tin, e, tin, p))(energy)
234
+ / self.Tin
235
+ * jnp.vectorize(lambda e: integral(tout, self.Tin, e, self.Tin, p))(energy)
345
236
  )
346
237
 
347
238
 
@@ -354,19 +245,25 @@ class Agauss(AdditiveComponent):
354
245
  $$\mathcal{M}\left( \lambda \right) =
355
246
  \frac{K}{\sigma \sqrt{2 \pi}} \exp\left(\frac{-(\lambda - \lambda_L)^2}{2 \sigma^2}\right)$$
356
247
 
357
- ??? abstract "Parameters"
358
- * $\lambda_L$ : Wavelength of the line in Angström $\left[\unicode{x212B}\right]$
359
- * $\sigma$ : Width of the line width in Angström $\left[\unicode{x212B}\right]$
360
- * $K$ : Normalization $\left[\frac{\text{photons}}{\text{cm}^2\text{s}}\right]$
248
+ !!! abstract "Parameters"
249
+ * $\lambda_L$ (`lambda_l`) $\left[\unicode{x212B}\right]$ : Wavelength of the line in Angström
250
+ * $\sigma$ (`sigma`) $\left[\unicode{x212B}\right]$ : Width of the line in Angström
251
+ * $K$ (`norm`) $\left[\frac{\unicode{x212B}~\text{photons}}{\text{keV}\text{cm}^2\text{s}}\right]$: Normalization
361
252
  """
362
253
 
254
+ def __init__(self):
255
+ self.lambda_l = nnx.Param(12.0)
256
+ self.sigma = nnx.Param(1e-2)
257
+ self.norm = nnx.Param(1.0)
258
+
363
259
  def continuum(self, energy) -> (jax.Array, jax.Array):
364
260
  hc = (astropy.constants.h * astropy.constants.c).to(u.angstrom * u.keV).value
365
- line_wavelength = hk.get_parameter("Lambda_l", [], float, init=HaikuConstant(hc))
366
- sigma = hk.get_parameter("sigma", [], float, init=HaikuConstant(0.001))
367
- norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
368
261
 
369
- return norm * jsp.stats.norm.pdf(hc / energy, loc=line_wavelength, scale=sigma)
262
+ return self.norm * jsp.stats.norm.pdf(
263
+ hc / energy,
264
+ loc=jnp.asarray(self.lambda_l, dtype=jnp.float64),
265
+ scale=jnp.asarray(self.sigma, dtype=jnp.float64),
266
+ )
370
267
 
371
268
 
372
269
  class Zagauss(AdditiveComponent):
@@ -377,24 +274,32 @@ class Zagauss(AdditiveComponent):
377
274
  $$\mathcal{M}\left( \lambda \right) =
378
275
  \frac{K (1+z)}{\sigma \sqrt{2 \pi}} \exp\left(\frac{-(\lambda/(1+z) - \lambda_L)^2}{2 \sigma^2}\right)$$
379
276
 
380
- ??? abstract "Parameters"
381
- * $\lambda_L$ : Wavelength of the line in Angström $\left[\text{\AA}\right]$
382
- * $\sigma$ : Width of the line width in Angström $\left[\text{\AA}\right]$
383
- * $K$ : Normalization $\left[\frac{\text{photons}}{\text{cm}^2\text{s}}\right]$
384
- * $z$ : Redshift [dimensionless]
277
+ !!! abstract "Parameters"
278
+ * $\lambda_L$ (`lambda_l`) $\left[\unicode{x212B}\right]$ : Wavelength of the line in Angström
279
+ * $\sigma$ (`sigma`) $\left[\unicode{x212B}\right]$ : Width of the line in Angström
280
+ * $z$ (`redshift`) $\left[\text{dimensionless}\right]$ : Redshift
281
+ * $K$ (`norm`) $\left[\frac{\unicode{x212B}~\text{photons}}{\text{keV}\text{cm}^2\text{s}}\right]$ : Normalization
385
282
  """
386
283
 
284
+ def __init__(self):
285
+ self.lambda_l = nnx.Param(12.0)
286
+ self.sigma = nnx.Param(1e-2)
287
+ self.redshift = nnx.Param(0.0)
288
+ self.norm = nnx.Param(1.0)
289
+
387
290
  def continuum(self, energy) -> (jax.Array, jax.Array):
388
291
  hc = (astropy.constants.h * astropy.constants.c).to(u.angstrom * u.keV).value
389
- line_wavelength = hk.get_parameter("Lambda_l", [], float, init=HaikuConstant(hc))
390
- sigma = hk.get_parameter("sigma", [], float, init=HaikuConstant(0.001))
391
- norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
392
- redshift = hk.get_parameter("redshift", [], float, init=HaikuConstant(0))
292
+
293
+ redshift = self.redshift
393
294
 
394
295
  return (
395
- norm
296
+ self.norm
396
297
  * (1 + redshift)
397
- * jsp.stats.norm.pdf((hc / energy) / (1 + redshift), loc=line_wavelength, scale=sigma)
298
+ * jsp.stats.norm.pdf(
299
+ (hc / energy) / (1 + redshift),
300
+ loc=jnp.asarray(self.lambda_l, dtype=jnp.float64),
301
+ scale=jnp.asarray(self.sigma, dtype=jnp.float64),
302
+ )
398
303
  )
399
304
 
400
305
 
@@ -405,21 +310,24 @@ class Zgauss(AdditiveComponent):
405
310
  $$\mathcal{M}\left( E \right) =
406
311
  \frac{K}{(1+z) \sigma \sqrt{2 \pi}}\exp\left(\frac{-(E(1+z)-E_L)^2}{2\sigma^2}\right)$$
407
312
 
408
- ??? abstract "Parameters"
409
- * $E_L$ : Energy of the line $\left[\text{keV}\right]$
410
- * $\sigma$ : Width of the line $\left[\text{keV}\right]$
411
- * $K$ : Normalization $\left[\frac{\text{photons}}{\text{cm}^2\text{s}}\right]$
412
- * $z$ : Redshift [dimensionless]
313
+ !!! abstract "Parameters"
314
+ * $E_L$ (`E_l`) $\left[\text{keV}\right]$ : Energy of the line
315
+ * $\sigma$ (`sigma`) $\left[\text{keV}\right]$ : Width of the line
316
+ * $K$ (`norm`) $\left[\frac{\text{photons}}{\text{cm}^2\text{s}}\right]$ : Normalization
317
+ * $z$ (`redshift`) $\left[\text{dimensionless}\right]$ : Redshift
413
318
  """
414
319
 
415
- def continuum(self, energy) -> (jax.Array, jax.Array):
416
- line_energy = hk.get_parameter("E_l", [], float, init=HaikuConstant(1))
417
- sigma = hk.get_parameter("sigma", [], float, init=HaikuConstant(1))
418
- norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
419
- redshift = hk.get_parameter("redshift", [], float, init=HaikuConstant(0))
320
+ def __init__(self):
321
+ self.E_l = nnx.Param(2.0)
322
+ self.sigma = nnx.Param(1e-2)
323
+ self.redshift = nnx.Param(0.0)
324
+ self.norm = nnx.Param(1.0)
420
325
 
421
- return (norm / (1 + redshift)) * jsp.stats.norm.pdf(
422
- energy * (1 + redshift), loc=line_energy, scale=sigma
326
+ def continuum(self, energy) -> (jax.Array, jax.Array):
327
+ return (self.norm / (1 + self.redshift)) * jsp.stats.norm.pdf(
328
+ energy * (1 + self.redshift),
329
+ loc=jnp.asarray(self.E_l, dtype=jnp.float64),
330
+ scale=jnp.asarray(self.sigma, dtype=jnp.float64),
423
331
  )
424
332
 
425
333
 
@@ -430,50 +338,57 @@ class NSatmos(AdditiveComponent):
430
338
  !!! warning
431
339
  The boundary case of $R_{\text{NS}} < 1.125 R_{\text{S}}$ is handled with a null flux instead of a constant value as in `XSPEC`.
432
340
 
433
- ??? abstract "Parameters"
434
- * $T_{eff}$ : Effective temperature at the surface in K (No redshift applied)
435
- * $M_{ns}$ : Mass of the NS in solar masses
436
- * $R_∞$ : Radius at infinity (modulated by gravitational effects) in km
437
- * $D$ : Distance to the neutron star in kpc
438
- * norm : fraction of the neutron star surface emitting
341
+ !!! abstract "Parameters"
342
+ * $\log_{10} T_{eff}$ (`Tinf`) $\left[\log_{10} \text{K}\right]$ : Base-10 logarithm of the effective temperature at the surface, with $T_{eff}$ in Kelvin (No redshift applied)
343
+ * $M_{ns}$ (`mass`) $\left[M_{\odot}\right]$ : Mass of the NS
344
+ * $R_∞$ (`radius`) $\left[\text{km}\right]$ : Radius at infinity (modulated by gravitational effects)
345
+ * $D$ (`distance`) $\left[\text{kpc}\right]$ : Distance to the neutron star
346
+ * norm (`norm`) $\left[\text{dimensionless}\right]$ : fraction of the neutron star surface emitting
439
347
  """
440
348
 
441
- def __init__(self, *args, **kwargs):
442
- super().__init__(*args, **kwargs)
349
+ def __init__(self):
350
+ self.Tinf = nnx.Param(6.0)
351
+ self.mass = nnx.Param(1.4)
352
+ self.radius = nnx.Param(10.0)
353
+ self.distance = nnx.Param(10.0)
354
+ self.norm = nnx.Param(1.0)
355
+
443
356
  entry_table = Table.read(table_manager.fetch("nsatmosdata.fits"), 1)
444
357
 
445
- # Extract the table values. All this code could be summarize in two lines if we reformat the nsatmosdata.fits table
446
- self.tab_temperature = np.asarray(entry_table["TEMP"][0], dtype=float) # Logarithmic value
447
- self.tab_gravity = np.asarray(entry_table["GRAVITY"][0], dtype=float) # Logarithmic value
448
- self.tab_mucrit = np.asarray(entry_table["MUCRIT"][0], dtype=float)
449
- self.tab_energy = np.asarray(entry_table["ENERGY"][0], dtype=float)
450
- self.tab_flux_flat = Table.read(table_manager.fetch("nsatmosdata.fits"), 2)["FLUX"]
358
+ # Extract the table values. All this code could be summarized in two lines if we reformat the nsatmosdata.fits table
359
+ tab_temperature = np.asarray(entry_table["TEMP"][0], dtype=float) # Logarithmic value
360
+ tab_gravity = np.asarray(entry_table["GRAVITY"][0], dtype=float) # Logarithmic value
361
+ tab_mucrit = np.asarray(entry_table["MUCRIT"][0], dtype=float)
362
+ tab_energy = np.asarray(entry_table["ENERGY"][0], dtype=float)
363
+ tab_flux_flat = Table.read(table_manager.fetch("nsatmosdata.fits"), 2)["FLUX"]
451
364
 
452
365
  tab_flux = np.empty(
453
366
  (
454
- self.tab_temperature.size,
455
- self.tab_gravity.size,
456
- self.tab_mucrit.size,
457
- self.tab_energy.size,
367
+ tab_temperature.size,
368
+ tab_gravity.size,
369
+ tab_mucrit.size,
370
+ tab_energy.size,
458
371
  )
459
372
  )
460
373
 
461
- for i in range(len(self.tab_temperature)):
462
- for j in range(len(self.tab_gravity)):
463
- for k in range(len(self.tab_mucrit)):
374
+ for i in range(len(tab_temperature)):
375
+ for j in range(len(tab_gravity)):
376
+ for k in range(len(tab_mucrit)):
464
377
  tab_flux[i, j, k] = np.array(
465
- self.tab_flux_flat[
466
- i * len(self.tab_gravity) * len(self.tab_mucrit)
467
- + j * len(self.tab_mucrit)
468
- + k
378
+ tab_flux_flat[
379
+ i * len(tab_gravity) * len(tab_mucrit) + j * len(tab_mucrit) + k
469
380
  ]
470
381
  )
471
382
 
472
- self.tab_flux = np.asarray(tab_flux, dtype=float)
383
+ tab_flux = np.asarray(tab_flux, dtype=float)
473
384
 
474
- def interp_flux_func(self, temperature_log, gravity_log, mu):
475
- # Interpolate the tables to get the flux on the tabulated energy grid
385
+ self.tab_temperature = nnx.Variable(tab_temperature)
386
+ self.tab_gravity = nnx.Variable(tab_gravity)
387
+ self.tab_mucrit = nnx.Variable(tab_mucrit)
388
+ self.tab_energy = nnx.Variable(tab_energy)
389
+ self.tab_flux = nnx.Variable(tab_flux)
476
390
 
391
+ def interp_flux_func(self, temperature_log, gravity_log, mu):
477
392
  return interpax.interp3d(
478
393
  10.0**temperature_log,
479
394
  10.0**gravity_log,
@@ -485,19 +400,18 @@ class NSatmos(AdditiveComponent):
485
400
  method="linear",
486
401
  )
487
402
 
403
+ @partial(jnp.vectorize, excluded=(0,))
488
404
  def continuum(self, energy):
489
- temp_log = hk.get_parameter(
490
- "Tinf", [], float, init=HaikuConstant(6.0)
491
- ) # log10 of temperature in Kelvin
492
-
493
- # 'Tinf': temp_log, # 5 to 6.5
494
- # 'M': mass, # 0.5 to 3
495
- # 'Rns': radius, # 5 to 30
405
+ # log10 of temperature in Kelvin
406
+ # 'Tinf': temp_log, # 5 to 6.5
407
+ # 'M': mass, # 0.5 to 3
408
+ # 'Rns': radius, # 5 to 30
496
409
 
497
- mass = hk.get_parameter("M", [], float, init=HaikuConstant(1.4))
498
- radius = hk.get_parameter("Rns", [], float, init=HaikuConstant(10.0))
499
- distance = hk.get_parameter("dns", [], float, init=HaikuConstant(10.0))
500
- norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1.0))
410
+ temp_log = self.Tinf
411
+ mass = self.mass
412
+ radius = self.radius
413
+ distance = self.distance
414
+ norm = self.norm
501
415
 
502
416
  # Derive the parameters needed to retrieve values from the flux table
503
417
  Rcgs = 1e5 * radius # Radius in cgs
@@ -519,7 +433,7 @@ class NSatmos(AdditiveComponent):
519
433
 
520
434
  # Interpolate the flux table to get the flux at the surface
521
435
 
522
- flux = jax.jit(self.interp_flux_func)(temp_log, gravity_log, cmu)
436
+ flux = self.interp_flux_func(temp_log, gravity_log, cmu)
523
437
 
524
438
  # Rescale the photon energies and fluxes back to the correct local temperature
525
439
  Tfactor = 10.0 ** (temp_log - 6.0)
@@ -542,3 +456,76 @@ class NSatmos(AdditiveComponent):
542
456
  )
543
457
 
544
458
  return jax.lax.select(r_normalized < 1.125, jnp.zeros_like(true_flux), true_flux)
459
+
460
+
461
+ class Band(AdditiveComponent):
462
+ r"""
463
+ A Band function model
464
+
465
+ $$
466
+ \mathcal{M}(E) = \begin{cases} K \left( \frac{E}{100 \, \text{keV}}\right)^{\alpha_1}\exp(-\frac{E}{E_c}) &
467
+ \text{if $E < E_c (\alpha_1 - \alpha_2)$} \\
468
+ K \left[ (\alpha_1 - \alpha_2) \frac{E_c}{100 \, \text{keV}} \right]^{\alpha_1-\alpha_2} \left( \frac{E}{100 \, \text{keV}}\right)^{\alpha_2} \exp(-(\alpha_1 - \alpha_2)) & \text{if $E > E_c (\alpha_1 - \alpha_2)$}
469
+ \end{cases}
470
+ $$
471
+
472
+ !!! abstract "Parameters"
473
+ * $\alpha_1$ (`alpha1`) $\left[\text{dimensionless}\right]$ : First powerlaw index
474
+ * $\alpha_2$ (`alpha2`) $\left[\text{dimensionless}\right]$ : Second powerlaw index
475
+ * $E_c$ (`Ec`) $\left[\text{keV}\right]$ : Characteristic energy of the exponential cutoff
476
+ * norm (`norm`) $\left[\frac{\text{photons}}{\text{keV}\text{cm}^2\text{s}}\right]$ : Normalization at the reference energy (100 keV)
477
+ """
478
+
479
+ def __init__(self):
480
+ self.alpha1 = nnx.Param(-1.0)
481
+ self.alpha2 = nnx.Param(-2.0)
482
+ self.Ec = nnx.Param(200.0)
483
+ self.norm = nnx.Param(1e-4)
484
+
485
+ def continuum(self, energy):
486
+ Epivot = 100.0
487
+ alpha_diff = jnp.asarray(self.alpha1) - jnp.asarray(self.alpha2)
488
+
489
+ spectrum = jnp.where(
490
+ energy < self.Ec * (alpha_diff),
491
+ (energy / Epivot) ** self.alpha1 * jnp.exp(-energy / self.Ec),
492
+ (alpha_diff * (self.Ec / Epivot)) ** (alpha_diff)
493
+ * (energy / 100) ** self.alpha2
494
+ * jnp.exp(-alpha_diff),
495
+ )
496
+
497
+ return self.norm * spectrum
498
+
499
+
500
+ '''
501
+ class Diskpbb(AdditiveComponent):
502
+ r"""
503
+ A multiple blackbody disk model where local disk temperature T(r) is proportional to $$r^{(-p)}$$,
504
+ where $$p$$ is a free parameter. The standard disk model, diskbb, is recovered if $$p=0.75$$.
505
+ If radial advection is important then $$p<0.75$$.
506
+
507
+ $$\\mathcal{M}\\left( E \right) = \frac{2\\pi(\\cos i)r^{2}_{\text{in}}}{pd^2} \\int_{T_{\text{in}}}^{T_{\text{out}}}
508
+ \\left( \frac{T}{T_{\text{in}}} \right)^{-2/p-1} \text{bbody}(E,T) \frac{dT}{T_{\text{in}}}$$
509
+
510
+ ??? abstract "Parameters"
511
+ * $\text{norm}$ : $\\cos i(r_{\text{in}}/d)^{2}$,
512
+ where $r_{\text{in}}$ is "an apparent" inner disk radius $\\left[\text{km}\right]$,
513
+ $d$ the distance to the source in units of $10 \text{kpc}$,
514
+ $i$ the angle of the disk ($i=0$ is face-on)
515
+ * $p$ : Exponent of the radial dependence of the disk temperature $\\left[\text{dimensionless}\right]$
516
+ * $T_{\text{in}}$ : Temperature at inner disk radius $\\left[ \\mathrm{keV}\right]$
517
+ """
518
+
519
+ def continuum(self, energy):
520
+ norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
521
+ p = hk.get_parameter("p", [], float, init=HaikuConstant(0.75))
522
+ tin = hk.get_parameter("Tin", [], float, init=HaikuConstant(1))
523
+
524
+ # Tout is set to 0 as it is evaluated at R=infinity
525
+ def integrand(kT, energy):
526
+ return 2.78e-3 * energy**2 * (kT / tin) ** (-2 / p - 1) / (jnp.exp(energy / kT) - 1)
527
+
528
+ func_vmapped = jax.vmap(lambda e: integrate_interval(lambda kT: integrand(kT, e), 0, tin, n=51))
529
+
530
+ return norm * (0.75 / p) * func_vmapped(energy)
531
+ '''