jaxspec 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxspec/model/additive.py CHANGED
@@ -1,16 +1,22 @@
1
1
  from __future__ import annotations
2
2
 
3
-
3
+ import astropy.constants
4
+ import astropy.units as u
4
5
  import haiku as hk
6
+ import interpax
5
7
  import jax
6
8
  import jax.numpy as jnp
7
9
  import jax.scipy as jsp
8
- import astropy.units as u
9
- import astropy.constants
10
+ import numpy as np
11
+
12
+ from astropy.table import Table
10
13
  from haiku.initializers import Constant as HaikuConstant
14
+
11
15
  from ..util.integrate import integrate_interval
16
+ from ..util.online_storage import table_manager
17
+
18
+ # from ._additive.apec import APEC
12
19
  from .abc import AdditiveComponent
13
- # from ._additive.apec import APEC # noqa: F401
14
20
 
15
21
 
16
22
  class Powerlaw(AdditiveComponent):
@@ -26,8 +32,8 @@ class Powerlaw(AdditiveComponent):
26
32
  """
27
33
 
28
34
  def continuum(self, energy):
29
- alpha = hk.get_parameter("alpha", [], init=HaikuConstant(1.3))
30
- norm = hk.get_parameter("norm", [], init=HaikuConstant(1e-4))
35
+ alpha = hk.get_parameter("alpha", [], float, init=HaikuConstant(1.3))
36
+ norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1e-4))
31
37
 
32
38
  return norm * energy ** (-alpha)
33
39
 
@@ -43,12 +49,12 @@ class AdditiveConstant(AdditiveComponent):
43
49
  """
44
50
 
45
51
  def continuum(self, energy):
46
- norm = hk.get_parameter("norm", [], init=HaikuConstant(1))
52
+ norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
47
53
 
48
54
  return norm * jnp.ones_like(energy)
49
55
 
50
56
  def primitive(self, energy):
51
- norm = hk.get_parameter("norm", [], init=HaikuConstant(1))
57
+ norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
52
58
 
53
59
  return norm * energy
54
60
 
@@ -66,9 +72,9 @@ class Lorentz(AdditiveComponent):
66
72
  """
67
73
 
68
74
  def continuum(self, energy):
69
- line_energy = hk.get_parameter("E_l", [], init=HaikuConstant(1))
70
- sigma = hk.get_parameter("sigma", [], init=HaikuConstant(1))
71
- norm = hk.get_parameter("norm", [], init=HaikuConstant(1))
75
+ line_energy = hk.get_parameter("E_l", [], float, init=HaikuConstant(1))
76
+ sigma = hk.get_parameter("sigma", [], float, init=HaikuConstant(1))
77
+ norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
72
78
 
73
79
  return norm * sigma / (2 * jnp.pi) / ((energy - line_energy) ** 2 + (sigma / 2) ** 2)
74
80
 
@@ -91,9 +97,9 @@ class Logparabola(AdditiveComponent):
91
97
 
92
98
  # TODO : conform with xspec definition
93
99
  def continuum(self, energy):
94
- a = hk.get_parameter("a", [], init=HaikuConstant(11 / 3))
95
- b = hk.get_parameter("b", [], init=HaikuConstant(0.2))
96
- norm = hk.get_parameter("norm", [], init=HaikuConstant(1))
100
+ a = hk.get_parameter("a", [], float, init=HaikuConstant(11 / 3))
101
+ b = hk.get_parameter("b", [], float, init=HaikuConstant(0.2))
102
+ norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
97
103
 
98
104
  return norm * energy ** (-(a + b * jnp.log(energy)))
99
105
 
@@ -111,8 +117,8 @@ class Blackbody(AdditiveComponent):
111
117
  """
112
118
 
113
119
  def continuum(self, energy):
114
- kT = hk.get_parameter("kT", [], init=HaikuConstant(11 / 3))
115
- norm = hk.get_parameter("norm", [], init=HaikuConstant(1))
120
+ kT = hk.get_parameter("kT", [], float, init=HaikuConstant(11 / 3))
121
+ norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
116
122
 
117
123
  return norm * 8.0525 * energy**2 / ((kT**4) * (jnp.exp(energy / kT) - 1))
118
124
 
@@ -130,8 +136,8 @@ class Blackbodyrad(AdditiveComponent):
130
136
  """
131
137
 
132
138
  def continuum(self, energy):
133
- kT = hk.get_parameter("kT", [], init=HaikuConstant(11 / 3))
134
- norm = hk.get_parameter("norm", [], init=HaikuConstant(1))
139
+ kT = hk.get_parameter("kT", [], float, init=HaikuConstant(11 / 3))
140
+ norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
135
141
 
136
142
  return norm * 1.0344e-3 * energy**2 / (jnp.exp(energy / kT) - 1)
137
143
 
@@ -150,9 +156,9 @@ class Gauss(AdditiveComponent):
150
156
  """
151
157
 
152
158
  def continuum(self, energy) -> (jax.Array, jax.Array):
153
- line_energy = hk.get_parameter("E_l", [], init=HaikuConstant(1))
154
- sigma = hk.get_parameter("sigma", [], init=HaikuConstant(1))
155
- norm = hk.get_parameter("norm", [], init=HaikuConstant(1))
159
+ line_energy = hk.get_parameter("E_l", [], float, init=HaikuConstant(1))
160
+ sigma = hk.get_parameter("sigma", [], float, init=HaikuConstant(1))
161
+ norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
156
162
 
157
163
  return norm * jsp.stats.norm.pdf(energy, loc=line_energy, scale=sigma)
158
164
 
@@ -198,9 +204,9 @@ class APEC(AdditiveComponent):
198
204
  )
199
205
 
200
206
  def mono_fine_structure(self, e_low, e_high) -> (jax.Array, jax.Array):
201
- norm = hk.get_parameter("norm", [], init=HaikuConstant(1))
202
- kT = hk.get_parameter("kT", [], init=HaikuConstant(1))
203
- Z = hk.get_parameter("Z", [], init=HaikuConstant(1))
207
+ norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
208
+ kT = hk.get_parameter("kT", [], float, init=HaikuConstant(1))
209
+ Z = hk.get_parameter("Z", [], float, init=HaikuConstant(1))
204
210
 
205
211
  idx = jnp.searchsorted(self.kT_ref, kT, side="left") - 1
206
212
 
@@ -229,9 +235,9 @@ class APEC(AdditiveComponent):
229
235
 
230
236
  @partial(jnp.vectorize, excluded=(0,))
231
237
  def continuum(self, energy):
232
- norm = hk.get_parameter("norm", [], init=HaikuConstant(1))
233
- kT = hk.get_parameter("kT", [], init=HaikuConstant(1))
234
- Z = hk.get_parameter("Z", [], init=HaikuConstant(1))
238
+ norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
239
+ kT = hk.get_parameter("kT", [], float, init=HaikuConstant(1))
240
+ Z = hk.get_parameter("Z", [], float, init=HaikuConstant(1))
235
241
 
236
242
  idx = jnp.searchsorted(self.kT_ref, kT, side="left") - 1 # index of left value
237
243
 
@@ -269,9 +275,9 @@ class Cutoffpl(AdditiveComponent):
269
275
  """
270
276
 
271
277
  def continuum(self, energy):
272
- alpha = hk.get_parameter("alpha", [], init=HaikuConstant(1.3))
273
- beta = hk.get_parameter("beta", [], init=HaikuConstant(15))
274
- norm = hk.get_parameter("norm", [], init=HaikuConstant(1e-4))
278
+ alpha = hk.get_parameter("alpha", [], float, init=HaikuConstant(1.3))
279
+ beta = hk.get_parameter("beta", [], float, init=HaikuConstant(15))
280
+ norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1e-4))
275
281
 
276
282
  return norm * energy ** (-alpha) * jnp.exp(-energy / beta)
277
283
 
@@ -283,22 +289,22 @@ class Diskpbb(AdditiveComponent):
283
289
  where $$p$$ is a free parameter. The standard disk model, diskbb, is recovered if $$p=0.75$$.
284
290
  If radial advection is important then $$p<0.75$$.
285
291
 
286
- $$\mathcal{M}\left( E \right) = \frac{2\pi(\cos i)r^{2}_{\text{in}}}{pd^2} \int_{T_{\text{in}}}^{T_{\text{out}}}
287
- \left( \frac{T}{T_{\text{in}}} \right)^{-2/p-1} \text{bbody}(E,T) \frac{dT}{T_{\text{in}}}$$
292
+ $$\\mathcal{M}\\left( E \right) = \frac{2\\pi(\\cos i)r^{2}_{\text{in}}}{pd^2} \\int_{T_{\text{in}}}^{T_{\text{out}}}
293
+ \\left( \frac{T}{T_{\text{in}}} \right)^{-2/p-1} \text{bbody}(E,T) \frac{dT}{T_{\text{in}}}$$
288
294
 
289
295
  ??? abstract "Parameters"
290
- * $\text{norm}$ : $\cos i(r_{\text{in}}/d)^{2}$,
291
- where $r_{\text{in}}$ is "an apparent" inner disk radius $\left[\text{km}\right]$,
296
+ * $\text{norm}$ : $\\cos i(r_{\text{in}}/d)^{2}$,
297
+ where $r_{\text{in}}$ is "an apparent" inner disk radius $\\left[\text{km}\right]$,
292
298
  $d$ the distance to the source in units of $10 \text{kpc}$,
293
299
  $i$ the angle of the disk ($i=0$ is face-on)
294
- * $p$ : Exponent of the radial dependence of the disk temperature $\left[\text{dimensionless}\right]$
295
- * $T_{\text{in}}$ : Temperature at inner disk radius $\left[ \mathrm{keV}\right]$
300
+ * $p$ : Exponent of the radial dependence of the disk temperature $\\left[\text{dimensionless}\right]$
301
+ * $T_{\text{in}}$ : Temperature at inner disk radius $\\left[ \\mathrm{keV}\right]$
296
302
  """
297
303
 
298
304
  def continuum(self, energy):
299
- norm = hk.get_parameter("norm", [], init=HaikuConstant(1))
300
- p = hk.get_parameter("p", [], init=HaikuConstant(0.75))
301
- tin = hk.get_parameter("Tin", [], init=HaikuConstant(1))
305
+ norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
306
+ p = hk.get_parameter("p", [], float, init=HaikuConstant(0.75))
307
+ tin = hk.get_parameter("Tin", [], float, init=HaikuConstant(1))
302
308
 
303
309
  # Tout is set to 0 as it is evaluated at R=infinity
304
310
  def integrand(kT, energy):
@@ -324,15 +330,21 @@ class Diskbb(AdditiveComponent):
324
330
  def continuum(self, energy):
325
331
  p = 0.75
326
332
  tout = 0.0
327
- tin = hk.get_parameter("Tin", [], init=HaikuConstant(1))
328
- norm = hk.get_parameter("norm", [], init=HaikuConstant(1))
333
+ tin = hk.get_parameter("Tin", [], float, init=HaikuConstant(1))
334
+ norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
329
335
 
330
336
  # Tout is set to 0 as it is evaluated at R=infinity
331
337
  def integrand(kT, e, tin, p):
332
338
  return e**2 * (kT / tin) ** (-2 / p - 1) / (jnp.exp(e / kT) - 1)
333
339
 
334
340
  integral = integrate_interval(integrand)
335
- return norm * 2.78e-3 * (0.75 / p) / tin * jnp.vectorize(lambda e: integral(tout, tin, e, tin, p))(energy)
341
+ return (
342
+ norm
343
+ * 2.78e-3
344
+ * (0.75 / p)
345
+ / tin
346
+ * jnp.vectorize(lambda e: integral(tout, tin, e, tin, p))(energy)
347
+ )
336
348
 
337
349
 
338
350
  class Agauss(AdditiveComponent):
@@ -352,9 +364,9 @@ class Agauss(AdditiveComponent):
352
364
 
353
365
  def continuum(self, energy) -> (jax.Array, jax.Array):
354
366
  hc = (astropy.constants.h * astropy.constants.c).to(u.angstrom * u.keV).value
355
- line_wavelength = hk.get_parameter("Lambda_l", [], init=HaikuConstant(hc))
356
- sigma = hk.get_parameter("sigma", [], init=HaikuConstant(0.001))
357
- norm = hk.get_parameter("norm", [], init=HaikuConstant(1))
367
+ line_wavelength = hk.get_parameter("Lambda_l", [], float, init=HaikuConstant(hc))
368
+ sigma = hk.get_parameter("sigma", [], float, init=HaikuConstant(0.001))
369
+ norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
358
370
 
359
371
  return norm * jsp.stats.norm.pdf(hc / energy, loc=line_wavelength, scale=sigma)
360
372
 
@@ -376,12 +388,16 @@ class Zagauss(AdditiveComponent):
376
388
 
377
389
  def continuum(self, energy) -> (jax.Array, jax.Array):
378
390
  hc = (astropy.constants.h * astropy.constants.c).to(u.angstrom * u.keV).value
379
- line_wavelength = hk.get_parameter("Lambda_l", [], init=HaikuConstant(hc))
380
- sigma = hk.get_parameter("sigma", [], init=HaikuConstant(0.001))
381
- norm = hk.get_parameter("norm", [], init=HaikuConstant(1))
382
- redshift = hk.get_parameter("redshift", [], init=HaikuConstant(0))
391
+ line_wavelength = hk.get_parameter("Lambda_l", [], float, init=HaikuConstant(hc))
392
+ sigma = hk.get_parameter("sigma", [], float, init=HaikuConstant(0.001))
393
+ norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
394
+ redshift = hk.get_parameter("redshift", [], float, init=HaikuConstant(0))
383
395
 
384
- return norm * (1 + redshift) * jsp.stats.norm.pdf((hc / energy) / (1 + redshift), loc=line_wavelength, scale=sigma)
396
+ return (
397
+ norm
398
+ * (1 + redshift)
399
+ * jsp.stats.norm.pdf((hc / energy) / (1 + redshift), loc=line_wavelength, scale=sigma)
400
+ )
385
401
 
386
402
 
387
403
  class Zgauss(AdditiveComponent):
@@ -399,9 +415,132 @@ class Zgauss(AdditiveComponent):
399
415
  """
400
416
 
401
417
  def continuum(self, energy) -> (jax.Array, jax.Array):
402
- line_energy = hk.get_parameter("E_l", [], init=HaikuConstant(1))
403
- sigma = hk.get_parameter("sigma", [], init=HaikuConstant(1))
404
- norm = hk.get_parameter("norm", [], init=HaikuConstant(1))
405
- redshift = hk.get_parameter("redshift", [], init=HaikuConstant(0))
418
+ line_energy = hk.get_parameter("E_l", [], float, init=HaikuConstant(1))
419
+ sigma = hk.get_parameter("sigma", [], float, init=HaikuConstant(1))
420
+ norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
421
+ redshift = hk.get_parameter("redshift", [], float, init=HaikuConstant(0))
422
+
423
+ return (norm / (1 + redshift)) * jsp.stats.norm.pdf(
424
+ energy * (1 + redshift), loc=line_energy, scale=sigma
425
+ )
426
+
427
+
428
+ class NSatmos(AdditiveComponent):
429
+ r"""
430
+ A neutron star atmosphere model based on the `NSATMOS` model from `XSPEC`. See [this page](https://heasarc.gsfc.nasa.gov/docs/xanadu/xspec/manual/node205.html)
431
+
432
+ !!! warning
433
+ The boundary case of $R_{\text{NS}} < 1.125 R_{\text{S}}$ is handled with a null flux instead of a constant value as in `XSPEC`.
434
+
435
+ ??? abstract "Parameters"
436
+ * $T_{eff}$ : Effective temperature at the surface in K (No redshift applied)
437
+ * $M_{ns}$ : Mass of the NS in solar masses
438
+ * $R_∞$ : Radius at infinity (modulated by gravitational effects) in km
439
+ * $D$ : Distance to the neutron star in kpc
440
+ * norm : fraction of the neutron star surface emitting
441
+ """
442
+
443
+ def __init__(self, *args, **kwargs):
444
+ super().__init__(*args, **kwargs)
445
+ entry_table = Table.read(table_manager.fetch("nsatmosdata.fits"), 1)
446
+
447
+ # Extract the table values. All this code could be summarize in two lines if we reformat the nsatmosdata.fits table
448
+ self.tab_temperature = np.asarray(entry_table["TEMP"][0], dtype=float) # Logarithmic value
449
+ self.tab_gravity = np.asarray(entry_table["GRAVITY"][0], dtype=float) # Logarithmic value
450
+ self.tab_mucrit = np.asarray(entry_table["MUCRIT"][0], dtype=float)
451
+ self.tab_energy = np.asarray(entry_table["ENERGY"][0], dtype=float)
452
+ self.tab_flux_flat = Table.read(table_manager.fetch("nsatmosdata.fits"), 2)["FLUX"]
453
+
454
+ tab_flux = np.empty(
455
+ (
456
+ self.tab_temperature.size,
457
+ self.tab_gravity.size,
458
+ self.tab_mucrit.size,
459
+ self.tab_energy.size,
460
+ )
461
+ )
462
+
463
+ for i in range(len(self.tab_temperature)):
464
+ for j in range(len(self.tab_gravity)):
465
+ for k in range(len(self.tab_mucrit)):
466
+ tab_flux[i, j, k] = np.array(
467
+ self.tab_flux_flat[
468
+ i * len(self.tab_gravity) * len(self.tab_mucrit)
469
+ + j * len(self.tab_mucrit)
470
+ + k
471
+ ]
472
+ )
473
+
474
+ self.tab_flux = np.asarray(tab_flux, dtype=float)
475
+
476
+ def interp_flux_func(self, temperature_log, gravity_log, mu):
477
+ # Interpolate the tables to get the flux on the tabulated energy grid
478
+
479
+ return interpax.interp3d(
480
+ 10.0**temperature_log,
481
+ 10.0**gravity_log,
482
+ mu,
483
+ 10.0**self.tab_temperature,
484
+ 10.0**self.tab_gravity,
485
+ self.tab_mucrit,
486
+ self.tab_flux,
487
+ method="linear",
488
+ )
489
+
490
+ def continuum(self, energy):
491
+ temp_log = hk.get_parameter(
492
+ "Tinf", [], float, init=HaikuConstant(6.0)
493
+ ) # log10 of temperature in Kelvin
494
+
495
+ # 'Tinf': temp_log, # 5 to 6.5
496
+ # 'M': mass, # 0.5 to 3
497
+ # 'Rns': radius, # 5 to 30
498
+
499
+ mass = hk.get_parameter("M", [], float, init=HaikuConstant(1.4))
500
+ radius = hk.get_parameter("Rns", [], float, init=HaikuConstant(10.0))
501
+ distance = hk.get_parameter("dns", [], float, init=HaikuConstant(10.0))
502
+ norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1.0))
503
+
504
+ # Derive parameters usable to retrive value in the flux table
505
+ Rcgs = 1e5 * radius # Radius in cgs
506
+ r_schwarzschild = 2.95e5 * mass # Schwarzschild radius in cgs
507
+ r_normalized = Rcgs / r_schwarzschild # Ratio of the radius to the Schwarzschild radius
508
+ r_over_Dsql = 2 * jnp.log10(
509
+ Rcgs / (3.09e21 * distance)
510
+ ) # Log( (R/D)**2 ), 3.09e21 constant transforms radius in cgs to kpc
511
+ zred1 = 1 / jnp.sqrt(1 - (1 / r_normalized)) # Gravitational redshift
512
+ gravity = (6.67e-8 * 1.99e33 * mass / Rcgs**2) * zred1 # Gravity field g in cgs
513
+ gravity_log = jnp.log10(
514
+ gravity
515
+ ) # Log gravity because this is the format given in the table
516
+
517
+ # Not sure about mu yet, but it is linked to causality
518
+ cmu = jnp.where(
519
+ r_normalized < 1.5, jnp.sqrt(1.0 - 6.75 / r_normalized**2 + 6.75 / r_normalized**3), 0.0
520
+ )
521
+
522
+ # Interpolate the flux table to get the flux at the surface
523
+
524
+ flux = jax.jit(self.interp_flux_func)(temp_log, gravity_log, cmu)
525
+
526
+ # Rescale the photon energies and fluxes back to the correct local temperature
527
+ Tfactor = 10.0 ** (temp_log - 6.0)
528
+ fluxshift = 3.0 * (temp_log - 6.0)
529
+ E = self.tab_energy * Tfactor
530
+ flux += fluxshift
531
+
532
+ # Rescale applying redshift
533
+ fluxshift = -jnp.log10(zred1)
534
+ E = E / zred1
535
+ flux += fluxshift
536
+
537
+ # Convert to counts/keV (which corresponds to dividing by1.602e-9*EkeV)
538
+ # Multiply by the area of the star, and calculate the count rate at the observer
539
+ flux += r_over_Dsql
540
+ counts = 10.0 ** (flux - jnp.log10(1.602e-9) - jnp.log10(E))
541
+
542
+ true_flux = norm * jnp.exp(
543
+ interpax.interp1d(jnp.log(energy), jnp.log(E), jnp.log(counts), method="linear")
544
+ )
406
545
 
407
- return (norm / (1 + redshift)) * jsp.stats.norm.pdf(energy * (1 + redshift), loc=line_energy, scale=sigma)
546
+ return jax.lax.select(r_normalized < 1.125, jnp.zeros_like(true_flux), true_flux)
@@ -1,9 +1,11 @@
1
1
  from abc import ABC, abstractmethod
2
- from tinygp import kernels, GaussianProcess
3
- from jax.scipy.integrate import trapezoid
4
- import numpyro.distributions as dist
2
+
5
3
  import jax.numpy as jnp
6
4
  import numpyro
5
+ import numpyro.distributions as dist
6
+
7
+ from jax.scipy.integrate import trapezoid
8
+ from tinygp import GaussianProcess, kernels
7
9
 
8
10
 
9
11
  class BackgroundModel(ABC):
@@ -33,7 +35,8 @@ class SubtractedBackground(BackgroundModel):
33
35
 
34
36
  """
35
37
 
36
- def numpyro_model(self, energy, observed_counts, name: str = "bkg", observed=True):
38
+ def numpyro_model(self, obs, spectral_model, name: str = "bkg", observed=True):
39
+ _, observed_counts = obs.out_energies, obs.folded_background.data
37
40
  numpyro.deterministic(f"{name}", observed_counts)
38
41
 
39
42
  return jnp.zeros_like(observed_counts)
@@ -49,32 +52,37 @@ class BackgroundWithError(BackgroundModel):
49
52
  but slower since it performs the fit using MCMC instead of analytical solution.
50
53
  """
51
54
 
52
- def numpyro_model(self, energy, observed_counts, name: str = "bkg", observed=True):
55
+ def numpyro_model(self, obs, spectral_model, name: str = "bkg", observed=True):
53
56
  # Gamma in numpyro is parameterized by concentration and rate (alpha/beta)
57
+ _, observed_counts = obs.out_energies, obs.folded_background.data
54
58
  alpha = observed_counts + 1
55
59
  beta = 1
56
60
  countrate = numpyro.sample(f"{name}_params", dist.Gamma(alpha, rate=beta))
57
61
 
58
62
  with numpyro.plate(f"{name}_plate", len(observed_counts)):
59
- numpyro.sample(f"{name}", dist.Poisson(countrate), obs=observed_counts if observed else None)
63
+ numpyro.sample(
64
+ f"{name}", dist.Poisson(countrate), obs=observed_counts if observed else None
65
+ )
60
66
 
61
67
  return countrate
62
68
 
63
69
 
64
70
  '''
71
+ # TODO: Implement this class and sample it with Gibbs Sampling
72
+
65
73
  class ConjugateBackground(BackgroundModel):
66
74
  r"""
67
- This class fit an expected rate $\lambda$ in each bin of the background spectrum. Assuming a Gamma prior
75
+ This class fit an expected rate $\\lambda$ in each bin of the background spectrum. Assuming a Gamma prior
68
76
  distribution, we can analytically derive the posterior as a Negative binomial distribution.
69
77
 
70
- $$ p(\lambda_{\text{Bkg}}) \sim \Gamma \left( \alpha, \beta \right) \implies
71
- p\left(\lambda_{\text{Bkg}} | \text{Counts}_{\text{Bkg}}\right) \sim \text{NB}\left(\alpha, \frac{\beta}{\beta +1}
78
+ $$ p(\\lambda_{\text{Bkg}}) \\sim \\Gamma \\left( \alpha, \beta \right) \\implies
79
+ p\\left(\\lambda_{\text{Bkg}} | \text{Counts}_{\text{Bkg}}\right) \\sim \text{NB}\\left(\alpha, \frac{\beta}{\beta +1}
72
80
  \right) $$
73
81
 
74
82
  !!! info
75
83
  Here, $\alpha$ and $\beta$ are set to $\alpha = \text{Counts}_{\text{Bkg}} + 1$ and $\beta = 1$. Doing so,
76
- the prior distribution is such that $\mathbb{E}[\lambda_{\text{Bkg}}] = \text{Counts}_{\text{Bkg}} +1$ and
77
- $\text{Var}[\lambda_{\text{Bkg}}] = \text{Counts}_{\text{Bkg}}+1$. The +1 is to avoid numerical issues when the
84
+ the prior distribution is such that $\\mathbb{E}[\\lambda_{\text{Bkg}}] = \text{Counts}_{\text{Bkg}} +1$ and
85
+ $\text{Var}[\\lambda_{\text{Bkg}}] = \text{Counts}_{\text{Bkg}}+1$. The +1 is to avoid numerical issues when the
78
86
  counts are 0, and add a small scatter even if the measured background is effectively null.
79
87
 
80
88
  ??? abstract "References"
@@ -97,6 +105,19 @@ class ConjugateBackground(BackgroundModel):
97
105
  return countrate
98
106
  '''
99
107
 
108
+ """
109
+ class SpectralBackgroundModel(BackgroundModel):
110
+ # I should pass the current spectral model as an argument to the background model
111
+ # In the numpyro model function
112
+ def __init__(self, model, prior):
113
+ self.model = model
114
+ self.prior = prior
115
+
116
+ def numpyro_model(self, energy, observed_counts, name: str = "bkg", observed=True):
117
+ #TODO : keep the sparsification from top model
118
+ transformed_model = hk.without_apply_rng(hk.transform(lambda par: CountForwardModel(model, obs, sparse=False)(par)))
119
+ """
120
+
100
121
 
101
122
  class GaussianProcessBackground(BackgroundModel):
102
123
  """
@@ -104,7 +125,13 @@ class GaussianProcessBackground(BackgroundModel):
104
125
  [`tinygp`](https://tinygp.readthedocs.io/en/stable/guide.html) library.
105
126
  """
106
127
 
107
- def __init__(self, e_min: float, e_max: float, n_nodes: int = 30, kernel: kernels.Kernel = kernels.Matern52):
128
+ def __init__(
129
+ self,
130
+ e_min: float,
131
+ e_max: float,
132
+ n_nodes: int = 30,
133
+ kernel: kernels.Kernel = kernels.Matern52,
134
+ ):
108
135
  """
109
136
  Build the Gaussian Process background model.
110
137
 
@@ -119,7 +146,7 @@ class GaussianProcessBackground(BackgroundModel):
119
146
  self.n_nodes = n_nodes
120
147
  self.kernel = kernel
121
148
 
122
- def numpyro_model(self, energy, observed_counts, name: str = "bkg", observed=True):
149
+ def numpyro_model(self, obs, spectral_model, name: str = "bkg", observed=True):
123
150
  """
124
151
  Build the model for the background.
125
152
 
@@ -129,9 +156,12 @@ class GaussianProcessBackground(BackgroundModel):
129
156
  name: The name of the background model for parameters disambiguation.
130
157
  observed: Whether the model is observed or not. Useful for `numpyro.infer.Predictive` calls.
131
158
  """
159
+ energy, observed_counts = obs.out_energies, obs.folded_background.data
132
160
 
133
161
  if (observed_counts is not None) and (self.n_nodes >= len(observed_counts)):
134
- raise RuntimeError("More nodes than channels in the observation associated with GaussianProcessBackground.")
162
+ raise RuntimeError(
163
+ "More nodes than channels in the observation associated with GaussianProcessBackground."
164
+ )
135
165
 
136
166
  # The parameters of the GP model
137
167
  mean = numpyro.sample(f"{name}_mean", dist.Normal(jnp.log(jnp.mean(observed_counts)), 2.0))
@@ -144,13 +174,17 @@ class GaussianProcessBackground(BackgroundModel):
144
174
  gp = GaussianProcess(kernel, nodes, diag=1e-5 * jnp.ones_like(nodes), mean=mean)
145
175
 
146
176
  log_rate = numpyro.sample(f"_{name}_log_rate_nodes", gp.numpyro_dist())
147
- interp_count_rate = jnp.exp(jnp.interp(energy, nodes * (self.e_max - self.e_min) + self.e_min, log_rate))
177
+ interp_count_rate = jnp.exp(
178
+ jnp.interp(energy, nodes * (self.e_max - self.e_min) + self.e_min, log_rate)
179
+ )
148
180
  count_rate = trapezoid(interp_count_rate, energy, axis=0)
149
181
 
150
182
  # Finally, our observation model is Poisson
151
183
  with numpyro.plate(f"{name}_plate", len(observed_counts)):
152
184
  # TODO : change to Poisson Likelihood when there is no background model
153
185
  # TODO : Otherwise clip the background model to 1e-6 to avoid numerical issues
154
- numpyro.sample(f"{name}", dist.Poisson(count_rate), obs=observed_counts if observed else None)
186
+ numpyro.sample(
187
+ f"{name}", dist.Poisson(count_rate), obs=observed_counts if observed else None
188
+ )
155
189
 
156
190
  return count_rate