jaxspec 0.0.7__py3-none-any.whl → 0.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxspec/model/additive.py CHANGED
@@ -3,13 +3,17 @@ from __future__ import annotations
3
3
  import astropy.constants
4
4
  import astropy.units as u
5
5
  import haiku as hk
6
+ import interpax
6
7
  import jax
7
8
  import jax.numpy as jnp
8
9
  import jax.scipy as jsp
10
+ import numpy as np
9
11
 
12
+ from astropy.table import Table
10
13
  from haiku.initializers import Constant as HaikuConstant
11
14
 
12
15
  from ..util.integrate import integrate_interval
16
+ from ..util.online_storage import table_manager
13
17
 
14
18
  # from ._additive.apec import APEC
15
19
  from .abc import AdditiveComponent
@@ -28,8 +32,8 @@ class Powerlaw(AdditiveComponent):
28
32
  """
29
33
 
30
34
  def continuum(self, energy):
31
- alpha = hk.get_parameter("alpha", [], init=HaikuConstant(1.3))
32
- norm = hk.get_parameter("norm", [], init=HaikuConstant(1e-4))
35
+ alpha = hk.get_parameter("alpha", [], float, init=HaikuConstant(1.3))
36
+ norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1e-4))
33
37
 
34
38
  return norm * energy ** (-alpha)
35
39
 
@@ -45,12 +49,12 @@ class AdditiveConstant(AdditiveComponent):
45
49
  """
46
50
 
47
51
  def continuum(self, energy):
48
- norm = hk.get_parameter("norm", [], init=HaikuConstant(1))
52
+ norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
49
53
 
50
54
  return norm * jnp.ones_like(energy)
51
55
 
52
56
  def primitive(self, energy):
53
- norm = hk.get_parameter("norm", [], init=HaikuConstant(1))
57
+ norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
54
58
 
55
59
  return norm * energy
56
60
 
@@ -68,9 +72,9 @@ class Lorentz(AdditiveComponent):
68
72
  """
69
73
 
70
74
  def continuum(self, energy):
71
- line_energy = hk.get_parameter("E_l", [], init=HaikuConstant(1))
72
- sigma = hk.get_parameter("sigma", [], init=HaikuConstant(1))
73
- norm = hk.get_parameter("norm", [], init=HaikuConstant(1))
75
+ line_energy = hk.get_parameter("E_l", [], float, init=HaikuConstant(1))
76
+ sigma = hk.get_parameter("sigma", [], float, init=HaikuConstant(1))
77
+ norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
74
78
 
75
79
  return norm * sigma / (2 * jnp.pi) / ((energy - line_energy) ** 2 + (sigma / 2) ** 2)
76
80
 
@@ -93,9 +97,9 @@ class Logparabola(AdditiveComponent):
93
97
 
94
98
  # TODO : conform with xspec definition
95
99
  def continuum(self, energy):
96
- a = hk.get_parameter("a", [], init=HaikuConstant(11 / 3))
97
- b = hk.get_parameter("b", [], init=HaikuConstant(0.2))
98
- norm = hk.get_parameter("norm", [], init=HaikuConstant(1))
100
+ a = hk.get_parameter("a", [], float, init=HaikuConstant(11 / 3))
101
+ b = hk.get_parameter("b", [], float, init=HaikuConstant(0.2))
102
+ norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
99
103
 
100
104
  return norm * energy ** (-(a + b * jnp.log(energy)))
101
105
 
@@ -113,8 +117,8 @@ class Blackbody(AdditiveComponent):
113
117
  """
114
118
 
115
119
  def continuum(self, energy):
116
- kT = hk.get_parameter("kT", [], init=HaikuConstant(11 / 3))
117
- norm = hk.get_parameter("norm", [], init=HaikuConstant(1))
120
+ kT = hk.get_parameter("kT", [], float, init=HaikuConstant(11 / 3))
121
+ norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
118
122
 
119
123
  return norm * 8.0525 * energy**2 / ((kT**4) * (jnp.exp(energy / kT) - 1))
120
124
 
@@ -132,8 +136,8 @@ class Blackbodyrad(AdditiveComponent):
132
136
  """
133
137
 
134
138
  def continuum(self, energy):
135
- kT = hk.get_parameter("kT", [], init=HaikuConstant(11 / 3))
136
- norm = hk.get_parameter("norm", [], init=HaikuConstant(1))
139
+ kT = hk.get_parameter("kT", [], float, init=HaikuConstant(11 / 3))
140
+ norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
137
141
 
138
142
  return norm * 1.0344e-3 * energy**2 / (jnp.exp(energy / kT) - 1)
139
143
 
@@ -152,9 +156,9 @@ class Gauss(AdditiveComponent):
152
156
  """
153
157
 
154
158
  def continuum(self, energy) -> (jax.Array, jax.Array):
155
- line_energy = hk.get_parameter("E_l", [], init=HaikuConstant(1))
156
- sigma = hk.get_parameter("sigma", [], init=HaikuConstant(1))
157
- norm = hk.get_parameter("norm", [], init=HaikuConstant(1))
159
+ line_energy = hk.get_parameter("E_l", [], float, init=HaikuConstant(1))
160
+ sigma = hk.get_parameter("sigma", [], float, init=HaikuConstant(1))
161
+ norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
158
162
 
159
163
  return norm * jsp.stats.norm.pdf(energy, loc=line_energy, scale=sigma)
160
164
 
@@ -200,9 +204,9 @@ class APEC(AdditiveComponent):
200
204
  )
201
205
 
202
206
  def mono_fine_structure(self, e_low, e_high) -> (jax.Array, jax.Array):
203
- norm = hk.get_parameter("norm", [], init=HaikuConstant(1))
204
- kT = hk.get_parameter("kT", [], init=HaikuConstant(1))
205
- Z = hk.get_parameter("Z", [], init=HaikuConstant(1))
207
+ norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
208
+ kT = hk.get_parameter("kT", [], float, init=HaikuConstant(1))
209
+ Z = hk.get_parameter("Z", [], float, init=HaikuConstant(1))
206
210
 
207
211
  idx = jnp.searchsorted(self.kT_ref, kT, side="left") - 1
208
212
 
@@ -231,9 +235,9 @@ class APEC(AdditiveComponent):
231
235
 
232
236
  @partial(jnp.vectorize, excluded=(0,))
233
237
  def continuum(self, energy):
234
- norm = hk.get_parameter("norm", [], init=HaikuConstant(1))
235
- kT = hk.get_parameter("kT", [], init=HaikuConstant(1))
236
- Z = hk.get_parameter("Z", [], init=HaikuConstant(1))
238
+ norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
239
+ kT = hk.get_parameter("kT", [], float, init=HaikuConstant(1))
240
+ Z = hk.get_parameter("Z", [], float, init=HaikuConstant(1))
237
241
 
238
242
  idx = jnp.searchsorted(self.kT_ref, kT, side="left") - 1 # index of left value
239
243
 
@@ -271,9 +275,9 @@ class Cutoffpl(AdditiveComponent):
271
275
  """
272
276
 
273
277
  def continuum(self, energy):
274
- alpha = hk.get_parameter("alpha", [], init=HaikuConstant(1.3))
275
- beta = hk.get_parameter("beta", [], init=HaikuConstant(15))
276
- norm = hk.get_parameter("norm", [], init=HaikuConstant(1e-4))
278
+ alpha = hk.get_parameter("alpha", [], float, init=HaikuConstant(1.3))
279
+ beta = hk.get_parameter("beta", [], float, init=HaikuConstant(15))
280
+ norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1e-4))
277
281
 
278
282
  return norm * energy ** (-alpha) * jnp.exp(-energy / beta)
279
283
 
@@ -298,9 +302,9 @@ class Diskpbb(AdditiveComponent):
298
302
  """
299
303
 
300
304
  def continuum(self, energy):
301
- norm = hk.get_parameter("norm", [], init=HaikuConstant(1))
302
- p = hk.get_parameter("p", [], init=HaikuConstant(0.75))
303
- tin = hk.get_parameter("Tin", [], init=HaikuConstant(1))
305
+ norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
306
+ p = hk.get_parameter("p", [], float, init=HaikuConstant(0.75))
307
+ tin = hk.get_parameter("Tin", [], float, init=HaikuConstant(1))
304
308
 
305
309
  # Tout is set to 0 as it is evaluated at R=infinity
306
310
  def integrand(kT, energy):
@@ -326,8 +330,8 @@ class Diskbb(AdditiveComponent):
326
330
  def continuum(self, energy):
327
331
  p = 0.75
328
332
  tout = 0.0
329
- tin = hk.get_parameter("Tin", [], init=HaikuConstant(1))
330
- norm = hk.get_parameter("norm", [], init=HaikuConstant(1))
333
+ tin = hk.get_parameter("Tin", [], float, init=HaikuConstant(1))
334
+ norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
331
335
 
332
336
  # Tout is set to 0 as it is evaluated at R=infinity
333
337
  def integrand(kT, e, tin, p):
@@ -360,9 +364,9 @@ class Agauss(AdditiveComponent):
360
364
 
361
365
  def continuum(self, energy) -> (jax.Array, jax.Array):
362
366
  hc = (astropy.constants.h * astropy.constants.c).to(u.angstrom * u.keV).value
363
- line_wavelength = hk.get_parameter("Lambda_l", [], init=HaikuConstant(hc))
364
- sigma = hk.get_parameter("sigma", [], init=HaikuConstant(0.001))
365
- norm = hk.get_parameter("norm", [], init=HaikuConstant(1))
367
+ line_wavelength = hk.get_parameter("Lambda_l", [], float, init=HaikuConstant(hc))
368
+ sigma = hk.get_parameter("sigma", [], float, init=HaikuConstant(0.001))
369
+ norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
366
370
 
367
371
  return norm * jsp.stats.norm.pdf(hc / energy, loc=line_wavelength, scale=sigma)
368
372
 
@@ -384,10 +388,10 @@ class Zagauss(AdditiveComponent):
384
388
 
385
389
  def continuum(self, energy) -> (jax.Array, jax.Array):
386
390
  hc = (astropy.constants.h * astropy.constants.c).to(u.angstrom * u.keV).value
387
- line_wavelength = hk.get_parameter("Lambda_l", [], init=HaikuConstant(hc))
388
- sigma = hk.get_parameter("sigma", [], init=HaikuConstant(0.001))
389
- norm = hk.get_parameter("norm", [], init=HaikuConstant(1))
390
- redshift = hk.get_parameter("redshift", [], init=HaikuConstant(0))
391
+ line_wavelength = hk.get_parameter("Lambda_l", [], float, init=HaikuConstant(hc))
392
+ sigma = hk.get_parameter("sigma", [], float, init=HaikuConstant(0.001))
393
+ norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
394
+ redshift = hk.get_parameter("redshift", [], float, init=HaikuConstant(0))
391
395
 
392
396
  return (
393
397
  norm
@@ -411,11 +415,132 @@ class Zgauss(AdditiveComponent):
411
415
  """
412
416
 
413
417
  def continuum(self, energy) -> (jax.Array, jax.Array):
414
- line_energy = hk.get_parameter("E_l", [], init=HaikuConstant(1))
415
- sigma = hk.get_parameter("sigma", [], init=HaikuConstant(1))
416
- norm = hk.get_parameter("norm", [], init=HaikuConstant(1))
417
- redshift = hk.get_parameter("redshift", [], init=HaikuConstant(0))
418
+ line_energy = hk.get_parameter("E_l", [], float, init=HaikuConstant(1))
419
+ sigma = hk.get_parameter("sigma", [], float, init=HaikuConstant(1))
420
+ norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1))
421
+ redshift = hk.get_parameter("redshift", [], float, init=HaikuConstant(0))
418
422
 
419
423
  return (norm / (1 + redshift)) * jsp.stats.norm.pdf(
420
424
  energy * (1 + redshift), loc=line_energy, scale=sigma
421
425
  )
426
+
427
+
428
+ class NSatmos(AdditiveComponent):
429
+ r"""
430
+ A neutron star atmosphere model based on the `NSATMOS` model from `XSPEC`. See [this page](https://heasarc.gsfc.nasa.gov/docs/xanadu/xspec/manual/node205.html)
431
+
432
+ !!! warning
433
+ The boundary case of $R_{\text{NS}} < 1.125 R_{\text{S}}$ is handled with a null flux instead of a constant value as in `XSPEC`.
434
+
435
+ ??? abstract "Parameters"
436
+ * $T_{eff}$ : Effective temperature at the surface in K (No redshift applied)
437
+ * $M_{ns}$ : Mass of the NS in solar masses
438
+ * $R_∞$ : Radius at infinity (modulated by gravitational effects) in km
439
+ * $D$ : Distance to the neutron star in kpc
440
+ * norm : fraction of the neutron star surface emitting
441
+ """
442
+
443
+ def __init__(self, *args, **kwargs):
444
+ super().__init__(*args, **kwargs)
445
+ entry_table = Table.read(table_manager.fetch("nsatmosdata.fits"), 1)
446
+
447
+ # Extract the table values. All this code could be summarize in two lines if we reformat the nsatmosdata.fits table
448
+ self.tab_temperature = np.asarray(entry_table["TEMP"][0], dtype=float) # Logarithmic value
449
+ self.tab_gravity = np.asarray(entry_table["GRAVITY"][0], dtype=float) # Logarithmic value
450
+ self.tab_mucrit = np.asarray(entry_table["MUCRIT"][0], dtype=float)
451
+ self.tab_energy = np.asarray(entry_table["ENERGY"][0], dtype=float)
452
+ self.tab_flux_flat = Table.read(table_manager.fetch("nsatmosdata.fits"), 2)["FLUX"]
453
+
454
+ tab_flux = np.empty(
455
+ (
456
+ self.tab_temperature.size,
457
+ self.tab_gravity.size,
458
+ self.tab_mucrit.size,
459
+ self.tab_energy.size,
460
+ )
461
+ )
462
+
463
+ for i in range(len(self.tab_temperature)):
464
+ for j in range(len(self.tab_gravity)):
465
+ for k in range(len(self.tab_mucrit)):
466
+ tab_flux[i, j, k] = np.array(
467
+ self.tab_flux_flat[
468
+ i * len(self.tab_gravity) * len(self.tab_mucrit)
469
+ + j * len(self.tab_mucrit)
470
+ + k
471
+ ]
472
+ )
473
+
474
+ self.tab_flux = np.asarray(tab_flux, dtype=float)
475
+
476
+ def interp_flux_func(self, temperature_log, gravity_log, mu):
477
+ # Interpolate the tables to get the flux on the tabulated energy grid
478
+
479
+ return interpax.interp3d(
480
+ 10.0**temperature_log,
481
+ 10.0**gravity_log,
482
+ mu,
483
+ 10.0**self.tab_temperature,
484
+ 10.0**self.tab_gravity,
485
+ self.tab_mucrit,
486
+ self.tab_flux,
487
+ method="linear",
488
+ )
489
+
490
+ def continuum(self, energy):
491
+ temp_log = hk.get_parameter(
492
+ "Tinf", [], float, init=HaikuConstant(6.0)
493
+ ) # log10 of temperature in Kelvin
494
+
495
+ # 'Tinf': temp_log, # 5 to 6.5
496
+ # 'M': mass, # 0.5 to 3
497
+ # 'Rns': radius, # 5 to 30
498
+
499
+ mass = hk.get_parameter("M", [], float, init=HaikuConstant(1.4))
500
+ radius = hk.get_parameter("Rns", [], float, init=HaikuConstant(10.0))
501
+ distance = hk.get_parameter("dns", [], float, init=HaikuConstant(10.0))
502
+ norm = hk.get_parameter("norm", [], float, init=HaikuConstant(1.0))
503
+
504
+ # Derive parameters usable to retrive value in the flux table
505
+ Rcgs = 1e5 * radius # Radius in cgs
506
+ r_schwarzschild = 2.95e5 * mass # Schwarzschild radius in cgs
507
+ r_normalized = Rcgs / r_schwarzschild # Ratio of the radius to the Schwarzschild radius
508
+ r_over_Dsql = 2 * jnp.log10(
509
+ Rcgs / (3.09e21 * distance)
510
+ ) # Log( (R/D)**2 ), 3.09e21 constant transforms radius in cgs to kpc
511
+ zred1 = 1 / jnp.sqrt(1 - (1 / r_normalized)) # Gravitational redshift
512
+ gravity = (6.67e-8 * 1.99e33 * mass / Rcgs**2) * zred1 # Gravity field g in cgs
513
+ gravity_log = jnp.log10(
514
+ gravity
515
+ ) # Log gravity because this is the format given in the table
516
+
517
+ # Not sure about mu yet, but it is linked to causality
518
+ cmu = jnp.where(
519
+ r_normalized < 1.5, jnp.sqrt(1.0 - 6.75 / r_normalized**2 + 6.75 / r_normalized**3), 0.0
520
+ )
521
+
522
+ # Interpolate the flux table to get the flux at the surface
523
+
524
+ flux = jax.jit(self.interp_flux_func)(temp_log, gravity_log, cmu)
525
+
526
+ # Rescale the photon energies and fluxes back to the correct local temperature
527
+ Tfactor = 10.0 ** (temp_log - 6.0)
528
+ fluxshift = 3.0 * (temp_log - 6.0)
529
+ E = self.tab_energy * Tfactor
530
+ flux += fluxshift
531
+
532
+ # Rescale applying redshift
533
+ fluxshift = -jnp.log10(zred1)
534
+ E = E / zred1
535
+ flux += fluxshift
536
+
537
+ # Convert to counts/keV (which corresponds to dividing by1.602e-9*EkeV)
538
+ # Multiply by the area of the star, and calculate the count rate at the observer
539
+ flux += r_over_Dsql
540
+ counts = 10.0 ** (flux - jnp.log10(1.602e-9) - jnp.log10(E))
541
+
542
+ true_flux = norm * jnp.exp(
543
+ interpax.interp1d(jnp.log(energy), jnp.log(E), jnp.log(counts), method="linear")
544
+ )
545
+
546
+ return jax.lax.select(r_normalized < 1.125, jnp.zeros_like(true_flux), true_flux)
@@ -37,9 +37,9 @@ class Expfac(MultiplicativeComponent):
37
37
  """
38
38
 
39
39
  def continuum(self, energy):
40
- amplitude = hk.get_parameter("A", [], init=HaikuConstant(1))
41
- factor = hk.get_parameter("f", [], init=HaikuConstant(1))
42
- pivot = hk.get_parameter("E_c", [], init=HaikuConstant(1))
40
+ amplitude = hk.get_parameter("A", [], float, init=HaikuConstant(1))
41
+ factor = hk.get_parameter("f", [], float, init=HaikuConstant(1))
42
+ pivot = hk.get_parameter("E_c", [], float, init=HaikuConstant(1))
43
43
 
44
44
  return jnp.where(energy >= pivot, 1.0 + amplitude * jnp.exp(-factor * energy), 1.0)
45
45
 
@@ -63,12 +63,14 @@ class Tbabs(MultiplicativeComponent):
63
63
 
64
64
  """
65
65
 
66
- table = Table.read(table_manager.fetch("xsect_tbabs_wilm.fits"))
67
- energy = jnp.asarray(np.array(table["ENERGY"]), dtype=np.float64)
68
- sigma = jnp.asarray(np.array(table["SIGMA"]), dtype=np.float64)
66
+ def __init__(self, *args, **kwargs):
67
+ super().__init__(*args, **kwargs)
68
+ table = Table.read(table_manager.fetch("xsect_tbabs_wilm.fits"))
69
+ self.energy = jnp.asarray(np.array(table["ENERGY"]), dtype=np.float64)
70
+ self.sigma = jnp.asarray(np.array(table["SIGMA"]), dtype=np.float64)
69
71
 
70
72
  def continuum(self, energy):
71
- nh = hk.get_parameter("N_H", [], init=HaikuConstant(1))
73
+ nh = hk.get_parameter("N_H", [], float, init=HaikuConstant(1))
72
74
  sigma = jnp.interp(energy, self.energy, self.sigma, left=1e9, right=0.0)
73
75
 
74
76
  return jnp.exp(-nh * sigma)
@@ -84,12 +86,14 @@ class Phabs(MultiplicativeComponent):
84
86
 
85
87
  """
86
88
 
87
- table = Table.read(table_manager.fetch("xsect_phabs_aspl.fits"))
88
- energy = jnp.asarray(np.array(table["ENERGY"]), dtype=np.float64)
89
- sigma = jnp.asarray(np.array(table["SIGMA"]), dtype=np.float64)
89
+ def __init__(self, *args, **kwargs):
90
+ super().__init__(*args, **kwargs)
91
+ table = Table.read(table_manager.fetch("xsect_phabs_aspl.fits"))
92
+ self.energy = jnp.asarray(np.array(table["ENERGY"]), dtype=np.float64)
93
+ self.sigma = jnp.asarray(np.array(table["SIGMA"]), dtype=np.float64)
90
94
 
91
95
  def continuum(self, energy):
92
- nh = hk.get_parameter("N_H", [], init=HaikuConstant(1))
96
+ nh = hk.get_parameter("N_H", [], float, init=HaikuConstant(1))
93
97
  sigma = jnp.interp(energy, self.energy, self.sigma, left=jnp.inf, right=0.0)
94
98
 
95
99
  return jnp.exp(-nh * sigma)
@@ -105,12 +109,14 @@ class Wabs(MultiplicativeComponent):
105
109
 
106
110
  """
107
111
 
108
- table = Table.read(table_manager.fetch("xsect_wabs_angr.fits"))
109
- energy = jnp.asarray(np.array(table["ENERGY"]), dtype=np.float64)
110
- sigma = jnp.asarray(np.array(table["SIGMA"]), dtype=np.float64)
112
+ def __init__(self, *args, **kwargs):
113
+ super().__init__(*args, **kwargs)
114
+ table = Table.read(table_manager.fetch("xsect_wabs_angr.fits"))
115
+ self.energy = jnp.asarray(np.array(table["ENERGY"]), dtype=np.float64)
116
+ self.sigma = jnp.asarray(np.array(table["SIGMA"]), dtype=np.float64)
111
117
 
112
118
  def continuum(self, energy):
113
- nh = hk.get_parameter("N_H", [], init=HaikuConstant(1))
119
+ nh = hk.get_parameter("N_H", [], float, init=HaikuConstant(1))
114
120
  sigma = jnp.interp(energy, self.energy, self.sigma, left=1e9, right=0.0)
115
121
 
116
122
  return jnp.exp(-nh * sigma)
@@ -136,9 +142,9 @@ class Gabs(MultiplicativeComponent):
136
142
  """
137
143
 
138
144
  def continuum(self, energy):
139
- tau = hk.get_parameter("tau", [], init=HaikuConstant(1))
140
- sigma = hk.get_parameter("sigma", [], init=HaikuConstant(1))
141
- center = hk.get_parameter("E_0", [], init=HaikuConstant(1))
145
+ tau = hk.get_parameter("tau", [], float, init=HaikuConstant(1))
146
+ sigma = hk.get_parameter("sigma", [], float, init=HaikuConstant(1))
147
+ center = hk.get_parameter("E_0", [], float, init=HaikuConstant(1))
142
148
 
143
149
  return jnp.exp(
144
150
  -tau / (jnp.sqrt(2 * jnp.pi) * sigma) * jnp.exp(-0.5 * ((energy - center) / sigma) ** 2)
@@ -160,8 +166,8 @@ class Highecut(MultiplicativeComponent):
160
166
  """
161
167
 
162
168
  def continuum(self, energy):
163
- cutoff = hk.get_parameter("E_c", [], init=HaikuConstant(1))
164
- folding = hk.get_parameter("E_f", [], init=HaikuConstant(1))
169
+ cutoff = hk.get_parameter("E_c", [], float, init=HaikuConstant(1))
170
+ folding = hk.get_parameter("E_f", [], float, init=HaikuConstant(1))
165
171
 
166
172
  return jnp.where(energy <= cutoff, 1.0, jnp.exp((cutoff - energy) / folding))
167
173
 
@@ -182,9 +188,9 @@ class Zedge(MultiplicativeComponent):
182
188
  """
183
189
 
184
190
  def continuum(self, energy):
185
- E_c = hk.get_parameter("E_c", [], init=HaikuConstant(1))
186
- D = hk.get_parameter("D", [], init=HaikuConstant(1))
187
- z = hk.get_parameter("z", [], init=HaikuConstant(0))
191
+ E_c = hk.get_parameter("E_c", [], float, init=HaikuConstant(1))
192
+ D = hk.get_parameter("D", [], float, init=HaikuConstant(1))
193
+ z = hk.get_parameter("z", [], float, init=HaikuConstant(0))
188
194
 
189
195
  return jnp.where(energy <= E_c, 1.0, jnp.exp(-D * (energy * (1 + z) / E_c) ** 3))
190
196
 
@@ -207,13 +213,34 @@ class Tbpcf(MultiplicativeComponent):
207
213
 
208
214
  """
209
215
 
210
- table = Table.read(table_manager.fetch("xsect_tbabs_wilm.fits"))
211
- energy = jnp.asarray(np.array(table["ENERGY"]), dtype=np.float64)
212
- sigma = jnp.asarray(np.array(table["SIGMA"]), dtype=np.float64)
216
+ def __init__(self, *args, **kwargs):
217
+ super().__init__(*args, **kwargs)
218
+ table = Table.read(table_manager.fetch("xsect_tbabs_wilm.fits"))
219
+ self.energy = jnp.asarray(np.array(table["ENERGY"]), dtype=np.float64)
220
+ self.sigma = jnp.asarray(np.array(table["SIGMA"]), dtype=np.float64)
213
221
 
214
222
  def continuum(self, energy):
215
- nh = hk.get_parameter("N_H", [], init=HaikuConstant(1))
216
- f = hk.get_parameter("f", [], init=HaikuConstant(0.2))
223
+ nh = hk.get_parameter("N_H", [], float, init=HaikuConstant(1))
224
+ f = hk.get_parameter("f", [], float, init=HaikuConstant(0.2))
217
225
  sigma = jnp.interp(energy, self.energy, self.sigma, left=1e9, right=0.0)
218
226
 
219
227
  return f * jnp.exp(-nh * sigma) + (1 - f)
228
+
229
+ class FDcut(MultiplicativeComponent):
230
+ r"""
231
+ A Fermi-Dirac cutoff model.
232
+
233
+ $$
234
+ \mathcal{M}(E) = \left( 1 + \exp \left( \frac{E - E_c}{E_f} \right) \right)^{-1}
235
+ $$
236
+
237
+ ??? abstract "Parameters"
238
+ * $E_c$ : Cutoff energy $\left[\text{keV}\right]$
239
+ * $E_f$ : Folding energy $\left[\text{keV}\right]$
240
+ """
241
+
242
+ def continuum(self, energy):
243
+ cutoff = hk.get_parameter("E_c", [], init=HaikuConstant(1))
244
+ folding = hk.get_parameter("E_f", [], init=HaikuConstant(1))
245
+
246
+ return (1 + jnp.exp((energy - cutoff)/folding)) ** -1
@@ -11,5 +11,18 @@ table_manager = pooch.create(
11
11
  "xsect_tbabs_wilm.fits": "sha256:3cf45e45c9d671c4c4fc128314b7c3a68b30f096eede6b3eb08bf55224a44935",
12
12
  "xsect_phabs_aspl.fits": "sha256:3eaffba2a62e3a611e0a4e1ff4a57342d7d576f023d7bbb632710dc75b9a5019",
13
13
  "xsect_wabs_angr.fits": "sha256:9b3073a477a30b52e207f2c4bf79afc6ae19abba8f207190ac4c697024f74073",
14
+ "nsatmosdata.fits": "sha256:faca712f87710ecb866b4ab61be593a6813517c44f6e8e92d689b38d42e1b6dc",
15
+ "example_data/NGC7793_ULX4/MOS2background_spectrum.fits": "sha256:5387923be0bf39229f4390dd5e85095a3d534b43a69d6d3179b832ebb366d173",
16
+ "example_data/NGC7793_ULX4/MOS1background_spectrum.fits": "sha256:265fd7465fb1a355f915d9902443ba2fd2be9aface04723056a8376971e3cf14",
17
+ "example_data/NGC7793_ULX4/MOS2.rmf": "sha256:b6af00603dece33dcda35d093451c947059af2e1e45c31c5a0ffa223b7fb693d",
18
+ "example_data/NGC7793_ULX4/PN.arf": "sha256:0ee897a63b6de80589c2da758d7477c54ba601b788bf32d4d16bbffa839acb73",
19
+ "example_data/NGC7793_ULX4/MOS1.rmf": "sha256:2d1138d22c31c5398a4eed1170b0b88b07350ecfec7a7aef4f550871cb4309ae",
20
+ "example_data/NGC7793_ULX4/PN_spectrum_grp20.fits": "sha256:a985e06076bf060d3a5331f20413afa9208a8d771fa6c671e8918a1860577c90",
21
+ "example_data/NGC7793_ULX4/MOS2_spectrum_grp.fits": "sha256:dccc7eda9d3d2e4aac2af4ca13d41ab4acc621265004d50a1586a187f7a04ffc",
22
+ "example_data/NGC7793_ULX4/MOS1_spectrum_grp.fits": "sha256:7e1ff664545bab4fdce1ef94768715b4d87a39b252b61e070e71427e5c8692ac",
23
+ "example_data/NGC7793_ULX4/MOS1.arf": "sha256:9017ada6a391d46f9b569b8d0338fbabb62a5397e7c29eb0a16e4e02d4868159",
24
+ "example_data/NGC7793_ULX4/PN.rmf": "sha256:91ba9ef82da8b9f73e6a799dfe097b87c68a7020ac6c5aa0dcd4067bf9cb4287",
25
+ "example_data/NGC7793_ULX4/MOS2.arf": "sha256:a126ff5a95a5f4bb93ed846944cf411d6e1c448626cb73d347e33324663d8b3f",
26
+ "example_data/NGC7793_ULX4/PNbackground_spectrum.fits": "sha256:55e017e0c19b324245fef049dff2a7a2e49b9a391667ca9c4f667c4f683b1f49",
14
27
  },
15
28
  )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxspec
3
- Version: 0.0.7
3
+ Version: 0.0.8
4
4
  Summary: jaxspec is a bayesian spectral fitting library for X-ray astronomy.
5
5
  License: MIT
6
6
  Author: sdupourque
@@ -10,21 +10,22 @@ Classifier: License :: OSI Approved :: MIT License
10
10
  Classifier: Programming Language :: Python :: 3
11
11
  Classifier: Programming Language :: Python :: 3.10
12
12
  Classifier: Programming Language :: Python :: 3.11
13
- Requires-Dist: arviz (>=0.17.1,<0.19.0)
13
+ Requires-Dist: arviz (>=0.17.1,<0.20.0)
14
14
  Requires-Dist: astropy (>=6.0.0,<7.0.0)
15
15
  Requires-Dist: chainconsumer (>=1.0.0,<2.0.0)
16
16
  Requires-Dist: cmasher (>=1.6.3,<2.0.0)
17
17
  Requires-Dist: dm-haiku (>=0.0.11,<0.0.13)
18
18
  Requires-Dist: gpjax (>=0.8.0,<0.9.0)
19
- Requires-Dist: jax (>=0.4.29,<0.5.0)
20
- Requires-Dist: jaxlib (>=0.4.29,<0.5.0)
19
+ Requires-Dist: interpax (>=0.3.3,<0.4.0)
20
+ Requires-Dist: jax (>=0.4.30,<0.5.0)
21
+ Requires-Dist: jaxlib (>=0.4.30,<0.5.0)
21
22
  Requires-Dist: jaxns (>=2.5.1,<3.0.0)
22
23
  Requires-Dist: jaxopt (>=0.8.1,<0.9.0)
23
24
  Requires-Dist: matplotlib (>=3.8.0,<4.0.0)
24
- Requires-Dist: mendeleev (>=0.15,<0.17)
25
+ Requires-Dist: mendeleev (>=0.15,<0.18)
25
26
  Requires-Dist: mkdocstrings (>=0.24,<0.26)
26
27
  Requires-Dist: networkx (>=3.1,<4.0)
27
- Requires-Dist: numpy (>=1.26.1,<2.0.0)
28
+ Requires-Dist: numpy (<2.0.0)
28
29
  Requires-Dist: numpyro (>=0.15.0,<0.16.0)
29
30
  Requires-Dist: optimistix (>=0.0.7,<0.0.8)
30
31
  Requires-Dist: pandas (>=2.2.0,<3.0.0)
@@ -1,46 +1,42 @@
1
1
  jaxspec/__init__.py,sha256=Sbn02lX6Y-zNXk17N8dec22c5jeypiS0LkHmGfz7lWA,126
2
2
  jaxspec/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  jaxspec/analysis/compare.py,sha256=g2UFhmR9Zt-7cz5gQFOB6lXuklXB3yTyUvjTypOzoSY,725
4
- jaxspec/analysis/results.py,sha256=58IM_HS3q8xgW0espGgh11eBdIJYa-m2XccW_-pO2to,24495
4
+ jaxspec/analysis/results.py,sha256=mTSPPzKVsZ-yaz-BZvGRw195S9542VZb0kR1vv4np6Y,26541
5
5
  jaxspec/data/__init__.py,sha256=aantcYKC9kZFvaE-V2SIwSuLhIld17Kjrd9CIUu___Y,415
6
6
  jaxspec/data/example_data/MOS1.arf,sha256=kBetpqOR1G-bVpuNAzj7q7YqU5fnwp6woW5OAtSGgVk,34560
7
- jaxspec/data/example_data/MOS1.pha,sha256=fh_2ZFRbq0_c4e-UdocVtNh6ObJSth4HDnFCflyGkqw,83520
8
7
  jaxspec/data/example_data/MOS1.rmf,sha256=LRE40iwxxTmKTu0RcLC4iwc1Ds_senrvT1UIcctDCa4,14846400
9
8
  jaxspec/data/example_data/MOS1_spectrum_grp.fits,sha256=fh_2ZFRbq0_c4e-UdocVtNh6ObJSth4HDnFCflyGkqw,83520
10
9
  jaxspec/data/example_data/MOS1background_spectrum.fits,sha256=Jl_XRl-xo1X5FdmQJEO6L9K-mvrOBHIwVqg3aXHjzxQ,69120
11
10
  jaxspec/data/example_data/MOS2.arf,sha256=oSb_WpWl9LuT7YRpRM9BHW4cRIYmy3PTR-MzJGY9iz8,34560
12
- jaxspec/data/example_data/MOS2.pha,sha256=3Mx-2p09LkqsKvTKE9QatKzGISZQBNUKFYahh_egT_w,89280
13
11
  jaxspec/data/example_data/MOS2.rmf,sha256=tq8AYD3s4z3No10JNFHJRwWa8uHkXDHFoP-iI7f7aT0,14783040
14
12
  jaxspec/data/example_data/MOS2_spectrum_grp.fits,sha256=3Mx-2p09LkqsKvTKE9QatKzGISZQBNUKFYahh_egT_w,89280
15
13
  jaxspec/data/example_data/MOS2background_spectrum.fits,sha256=U4eSO-C_OSKfQ5DdXoUJWj1TS0OmnW0xebgy67Nm0XM,77760
16
14
  jaxspec/data/example_data/PN.arf,sha256=DuiXpjtt6AWJwtp1jXR3xUumAbeIvzLU0Wu_-oOay3M,31680
17
- jaxspec/data/example_data/PN.pha,sha256=qYXgYHa_Bg06UzHyBBOvqSCKjXcfpsZx6JGKGGBXfJA,138240
18
15
  jaxspec/data/example_data/PN.rmf,sha256=kbqe-C2oufc-anmd_gl7h8aKcCCsbFqg3NQGe_nLQoc,3962880
19
16
  jaxspec/data/example_data/PN_spectrum_grp20.fits,sha256=qYXgYHa_Bg06UzHyBBOvqSCKjXcfpsZx6JGKGGBXfJA,138240
20
17
  jaxspec/data/example_data/PNbackground_spectrum.fits,sha256=VeAX4MGbMkJF_vBJ3_KnouSbmjkWZ8qcT2Z8T2g7H0k,120960
21
- jaxspec/data/example_data/fakeit.pha,sha256=IhkeWkE-b3ELECd_Uasjo9h3cXgcjCYH20wDpXJ8LMk,60480
22
18
  jaxspec/data/grouping.py,sha256=hhgBt-voiH0DDSyePacaIGsaMnrYbJM_-ZeU66keC7I,622
23
19
  jaxspec/data/instrument.py,sha256=0pSf1p82g7syDMmKm13eVbYih-Veiq5DnwsyZe6_b4g,3890
24
20
  jaxspec/data/obsconf.py,sha256=0X9jR-pV-Pk4-EVuUdlVWgl_gBx8ZurVkRNrfKQWdC4,8663
25
21
  jaxspec/data/observation.py,sha256=1UnFu5ihZp9z-vP_I7tsFY8jhhIJunv46JyuE-acrg0,6394
26
22
  jaxspec/data/ogip.py,sha256=sv9p00qHS5pzw61pzWyyF0nV-E-RXySdSFK2tUavokA,9545
27
- jaxspec/data/util.py,sha256=TZg_zrH2qk3LiPmn7yw5nMN-XT3z23R9pq8HMYBB_uE,8231
28
- jaxspec/fit.py,sha256=yfP1INuHYfjpXvZfmkgJ84TbeGWuFg1nvkXduMfxgyk,22057
23
+ jaxspec/data/util.py,sha256=ycLPVE-cjn6VpUWYlBU1BGfw73ANXIBilyVAUOYOSj0,9540
24
+ jaxspec/fit.py,sha256=4rJ8Zcv-CGwjAEKQ5r6bKD4Abw9OMcZd__rns6S4fto,21375
29
25
  jaxspec/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
30
26
  jaxspec/model/_additive/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
31
27
  jaxspec/model/_additive/apec.py,sha256=r7CQqscAgR0BXC_AJqF6B7CPq3Byoo65Z-h9XgACZeU,12460
32
28
  jaxspec/model/_additive/apec_loaders.py,sha256=jkUoH0ezeYdaNw3oV10V0L-jt848SKp2thanLWLWp9k,2412
33
29
  jaxspec/model/abc.py,sha256=SWjKOOsqU5UJsVy63Tt9dDq8H2eTIbvK2C9iqgiR0cY,19817
34
- jaxspec/model/additive.py,sha256=ayGkL0ftmK-MVndYoHorPZVpNopi1MyyEtymLlBRl-o,16885
30
+ jaxspec/model/additive.py,sha256=CT2K2DVVeHKN1tee9-J3MYdEPqEOolLB2E7HU-RJKZw,22485
35
31
  jaxspec/model/background.py,sha256=QSFFiuyUEvuzXBx3QfkvVneUR8KKEP-VaANEVXcavDE,7865
36
32
  jaxspec/model/list.py,sha256=0RPAoscVz_zM1CWdx_Gd5wfrQWV5Nv4Kd4bSXu2ayUA,860
37
- jaxspec/model/multiplicative.py,sha256=8S7_agz32fdX-qxYnkL7fGXW2CnuwRGBZ5pYV4-b_5k,7194
33
+ jaxspec/model/multiplicative.py,sha256=TG3PCgS7oCuHwJ4TM4whw6pz318oo9MVvjSs4sQZVPc,8300
38
34
  jaxspec/util/__init__.py,sha256=vKurfp7p2hxHptJjXhXqFAXAikAGXAqISMJUqPeiGTw,1259
39
35
  jaxspec/util/abundance.py,sha256=fsC313taIlGzQsZNwbYsJupDWm7ZbqzGhY66Ku394Mw,8546
40
36
  jaxspec/util/integrate.py,sha256=_Ax_knpC7d4et2-QFkOUzVtNeQLX1-cwLvm-FRBxYcw,4505
41
- jaxspec/util/online_storage.py,sha256=cJGHsgPh3CukRwwlOcxV9eGvrainDOIP5AQP60tWLN4,818
37
+ jaxspec/util/online_storage.py,sha256=vm56RfcbFKpkRVfr0bXO7J9aQxuBq-I_oEgA26YIhCo,2469
42
38
  jaxspec/util/typing.py,sha256=qwZMKHivZlozoo0ESsiaQNkG99Dh3PE2Z-5aOQD9zc0,1650
43
- jaxspec-0.0.7.dist-info/LICENSE.md,sha256=2q5XoWzddts5IqzIcgYYMOL21puU3MfO8gvT3Ype1eQ,1073
44
- jaxspec-0.0.7.dist-info/METADATA,sha256=MX7-CIwPDBVjZtjRP3BMQnm0gKEcN5Z91yYW4n5PvXA,3375
45
- jaxspec-0.0.7.dist-info/WHEEL,sha256=d2fvjOD7sXsVzChCqf0Ty0JbHKBaLYwDbGQDwQTnJ50,88
46
- jaxspec-0.0.7.dist-info/RECORD,,
39
+ jaxspec-0.0.8.dist-info/LICENSE.md,sha256=2q5XoWzddts5IqzIcgYYMOL21puU3MfO8gvT3Ype1eQ,1073
40
+ jaxspec-0.0.8.dist-info/METADATA,sha256=vRPdPEBjgjTjvYpff1w-O7yTlcdAaoFF_jNTDst7P-M,3407
41
+ jaxspec-0.0.8.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
42
+ jaxspec-0.0.8.dist-info/RECORD,,
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: poetry-core 1.7.0
2
+ Generator: poetry-core 1.9.0
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any