jaxspec 0.3.2__py3-none-any.whl → 0.3.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxspec/analysis/results.py +83 -32
- jaxspec/data/__init__.py +2 -0
- jaxspec/data/instrument.py +2 -1
- jaxspec/data/util.py +2 -2
- jaxspec/fit/_fitter.py +18 -5
- jaxspec/model/abc.py +4 -3
- jaxspec/model/additive.py +17 -21
- jaxspec/model/multiplicative.py +54 -3
- jaxspec/util/online_storage.py +1 -0
- {jaxspec-0.3.2.dist-info → jaxspec-0.3.4.dist-info}/METADATA +4 -8
- {jaxspec-0.3.2.dist-info → jaxspec-0.3.4.dist-info}/RECORD +14 -14
- {jaxspec-0.3.2.dist-info → jaxspec-0.3.4.dist-info}/WHEEL +1 -1
- {jaxspec-0.3.2.dist-info → jaxspec-0.3.4.dist-info}/entry_points.txt +0 -0
- {jaxspec-0.3.2.dist-info → jaxspec-0.3.4.dist-info}/licenses/LICENSE.md +0 -0
jaxspec/analysis/results.py
CHANGED
|
@@ -77,8 +77,9 @@ class FitResult:
|
|
|
77
77
|
r"""
|
|
78
78
|
Convergence of the chain as computed by the $\hat{R}$ statistic.
|
|
79
79
|
"""
|
|
80
|
+
rhat = az.rhat(self.inference_data)
|
|
80
81
|
|
|
81
|
-
return
|
|
82
|
+
return bool((rhat.to_array() < 1.01).all())
|
|
82
83
|
|
|
83
84
|
def _ppc_folded_branches(self, obs_id):
|
|
84
85
|
obs = self.obsconfs[obs_id]
|
|
@@ -167,6 +168,8 @@ class FitResult:
|
|
|
167
168
|
e_max: float,
|
|
168
169
|
unit: Unit = u.photon / u.cm**2 / u.s,
|
|
169
170
|
register: bool = False,
|
|
171
|
+
n_points: int = 5,
|
|
172
|
+
n_grid: int = 1_000,
|
|
170
173
|
) -> ArrayLike:
|
|
171
174
|
"""
|
|
172
175
|
Compute the unfolded photon flux in a given energy band. The flux is then added to
|
|
@@ -177,29 +180,40 @@ class FitResult:
|
|
|
177
180
|
e_max: The upper bound of the energy band in observer frame.
|
|
178
181
|
unit: The unit of the photon flux.
|
|
179
182
|
register: Whether to register the flux with the other posterior parameters.
|
|
183
|
+
n_points: The number of points per bin to use for computing the unfolded spectrum.
|
|
184
|
+
n_grid: The number of grid points to use for computing the unfolded spectrum.
|
|
180
185
|
|
|
181
186
|
!!! warning
|
|
182
187
|
Computation of the folded flux is not implemented yet. Feel free to open an
|
|
183
188
|
[issue](https://github.com/renecotyfanboy/jaxspec/issues) in the GitHub repository.
|
|
184
189
|
"""
|
|
185
190
|
|
|
191
|
+
energy_grid = np.linspace(e_min, e_max, n_grid)
|
|
192
|
+
|
|
186
193
|
@jax.jit
|
|
187
194
|
@jnp.vectorize
|
|
188
195
|
def vectorized_flux(*pars):
|
|
189
196
|
parameters_pytree = jax.tree.unflatten(pytree_def, pars)
|
|
190
197
|
return self.model.photon_flux(
|
|
191
|
-
parameters_pytree,
|
|
192
|
-
)
|
|
198
|
+
parameters_pytree, energy_grid[:-1], energy_grid[1:], n_points=n_points
|
|
199
|
+
)
|
|
193
200
|
|
|
194
201
|
flat_tree, pytree_def = jax.tree.flatten(self.input_parameters)
|
|
195
|
-
flux = vectorized_flux(*flat_tree)
|
|
196
|
-
conversion_factor = (u.photon / u.cm**2 / u.s).to(unit)
|
|
197
|
-
value = flux * conversion_factor
|
|
202
|
+
flux = vectorized_flux(*flat_tree).sum(axis=-1) # Sum over all bins
|
|
203
|
+
conversion_factor = float((u.photon / u.cm**2 / u.s).to(unit))
|
|
204
|
+
value = np.asarray(flux * conversion_factor)
|
|
198
205
|
|
|
199
206
|
if register:
|
|
200
|
-
self.inference_data.posterior[f"photon_flux_{e_min:.1f}_{e_max:.1f}"] = (
|
|
201
|
-
|
|
202
|
-
|
|
207
|
+
self.inference_data.posterior[f"mod/~/photon_flux_{e_min:.1f}_{e_max:.1f}"] = (
|
|
208
|
+
xr.DataArray(
|
|
209
|
+
value,
|
|
210
|
+
dims=self.inference_data.posterior.dims,
|
|
211
|
+
coords={
|
|
212
|
+
"chain": self.inference_data.posterior.chain,
|
|
213
|
+
"draw": self.inference_data.posterior.draw,
|
|
214
|
+
},
|
|
215
|
+
name=f"mod/~/photon_flux_{e_min:.1f}_{e_max:.1f}",
|
|
216
|
+
)
|
|
203
217
|
)
|
|
204
218
|
|
|
205
219
|
return value
|
|
@@ -210,6 +224,8 @@ class FitResult:
|
|
|
210
224
|
e_max: float,
|
|
211
225
|
unit: Unit = u.erg / u.cm**2 / u.s,
|
|
212
226
|
register: bool = False,
|
|
227
|
+
n_points: int = 5,
|
|
228
|
+
n_grid: int = 1_000,
|
|
213
229
|
) -> ArrayLike:
|
|
214
230
|
"""
|
|
215
231
|
Compute the unfolded energy flux in a given energy band. The flux is then added to
|
|
@@ -220,29 +236,40 @@ class FitResult:
|
|
|
220
236
|
e_max: The upper bound of the energy band in observer frame.
|
|
221
237
|
unit: The unit of the energy flux.
|
|
222
238
|
register: Whether to register the flux with the other posterior parameters.
|
|
239
|
+
n_points: The number of points per bin to use for computing the unfolded spectrum.
|
|
240
|
+
n_grid: The number of grid points to use for computing the unfolded spectrum.
|
|
223
241
|
|
|
224
242
|
!!! warning
|
|
225
243
|
Computation of the folded flux is not implemented yet. Feel free to open an
|
|
226
244
|
[issue](https://github.com/renecotyfanboy/jaxspec/issues) in the GitHub repository.
|
|
227
245
|
"""
|
|
228
246
|
|
|
247
|
+
energy_grid = np.linspace(e_min, e_max, n_grid)
|
|
248
|
+
|
|
229
249
|
@jax.jit
|
|
230
250
|
@jnp.vectorize
|
|
231
251
|
def vectorized_flux(*pars):
|
|
232
252
|
parameters_pytree = jax.tree.unflatten(pytree_def, pars)
|
|
233
253
|
return self.model.energy_flux(
|
|
234
|
-
parameters_pytree,
|
|
235
|
-
)
|
|
254
|
+
parameters_pytree, energy_grid[:-1], energy_grid[1:], n_points=n_points
|
|
255
|
+
)
|
|
236
256
|
|
|
237
257
|
flat_tree, pytree_def = jax.tree.flatten(self.input_parameters)
|
|
238
|
-
flux = vectorized_flux(*flat_tree)
|
|
239
|
-
conversion_factor = (u.keV / u.cm**2 / u.s).to(unit)
|
|
240
|
-
value = flux * conversion_factor
|
|
241
|
-
|
|
258
|
+
flux = vectorized_flux(*flat_tree).sum(axis=-1) # Sum over all bins
|
|
259
|
+
conversion_factor = float((u.keV / u.cm**2 / u.s).to(unit))
|
|
260
|
+
value = np.asarray(flux * conversion_factor)
|
|
261
|
+
# TODO : ADD TESTS WITH BACKGROUND
|
|
242
262
|
if register:
|
|
243
|
-
self.inference_data.posterior[f"energy_flux_{e_min:.1f}_{e_max:.1f}"] = (
|
|
244
|
-
|
|
245
|
-
|
|
263
|
+
self.inference_data.posterior[f"mod/~/energy_flux_{e_min:.1f}_{e_max:.1f}"] = (
|
|
264
|
+
xr.DataArray(
|
|
265
|
+
value,
|
|
266
|
+
dims=self.inference_data.posterior.dims,
|
|
267
|
+
coords={
|
|
268
|
+
"chain": self.inference_data.posterior.chain,
|
|
269
|
+
"draw": self.inference_data.posterior.draw,
|
|
270
|
+
},
|
|
271
|
+
name=f"mod/~/energy_flux_{e_min:.1f}_{e_max:.1f}",
|
|
272
|
+
)
|
|
246
273
|
)
|
|
247
274
|
|
|
248
275
|
return value
|
|
@@ -257,6 +284,8 @@ class FitResult:
|
|
|
257
284
|
cosmology: Cosmology = Planck18,
|
|
258
285
|
unit: Unit = u.erg / u.s,
|
|
259
286
|
register: bool = False,
|
|
287
|
+
n_points: int = 5,
|
|
288
|
+
n_grid: int = 1_000,
|
|
260
289
|
) -> ArrayLike:
|
|
261
290
|
"""
|
|
262
291
|
Compute the luminosity of the source specifying its redshift. The luminosity is then added to
|
|
@@ -270,8 +299,12 @@ class FitResult:
|
|
|
270
299
|
cosmology: Chosen cosmology.
|
|
271
300
|
unit: The unit of the luminosity.
|
|
272
301
|
register: Whether to register the flux with the other posterior parameters.
|
|
302
|
+
n_points: The number of points per bin to use for computing the unfolded spectrum.
|
|
303
|
+
n_grid: The number of grid points to use for computing the unfolded spectrum.
|
|
273
304
|
"""
|
|
274
305
|
|
|
306
|
+
energy_grid = np.linspace(e_min, e_max, n_grid)
|
|
307
|
+
|
|
275
308
|
if not observer_frame:
|
|
276
309
|
raise NotImplementedError()
|
|
277
310
|
|
|
@@ -292,19 +325,28 @@ class FitResult:
|
|
|
292
325
|
parameters_pytree = jax.tree.unflatten(pytree_def, pars)
|
|
293
326
|
return self.model.energy_flux(
|
|
294
327
|
parameters_pytree,
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
n_points=
|
|
298
|
-
)
|
|
328
|
+
energy_grid[:-1] * (1 + redshift),
|
|
329
|
+
energy_grid[1:] * (1 + redshift),
|
|
330
|
+
n_points=n_points,
|
|
331
|
+
)
|
|
299
332
|
|
|
300
333
|
flat_tree, pytree_def = jax.tree.flatten(self.input_parameters)
|
|
301
|
-
flux = vectorized_flux(*flat_tree) * (u.keV / u.cm**2 / u.s)
|
|
302
|
-
value =
|
|
334
|
+
flux = vectorized_flux(*flat_tree).sum(axis=-1) * (u.keV / u.cm**2 / u.s)
|
|
335
|
+
value = np.asarray(
|
|
336
|
+
(flux * (4 * np.pi * cosmology.luminosity_distance(redshift) ** 2)).to(unit)
|
|
337
|
+
)
|
|
303
338
|
|
|
304
339
|
if register:
|
|
305
|
-
self.inference_data.posterior[f"luminosity_{e_min:.1f}_{e_max:.1f}"] = (
|
|
306
|
-
|
|
307
|
-
|
|
340
|
+
self.inference_data.posterior[f"mod/~/luminosity_{e_min:.1f}_{e_max:.1f}"] = (
|
|
341
|
+
xr.DataArray(
|
|
342
|
+
value,
|
|
343
|
+
dims=self.inference_data.posterior.dims,
|
|
344
|
+
coords={
|
|
345
|
+
"chain": self.inference_data.posterior.chain,
|
|
346
|
+
"draw": self.inference_data.posterior.draw,
|
|
347
|
+
},
|
|
348
|
+
name=f"mod/~/luminosity_{e_min:.1f}_{e_max:.1f}",
|
|
349
|
+
)
|
|
308
350
|
)
|
|
309
351
|
|
|
310
352
|
return value
|
|
@@ -315,10 +357,13 @@ class FitResult:
|
|
|
315
357
|
|
|
316
358
|
Parameters:
|
|
317
359
|
name: The name of the chain.
|
|
360
|
+
parameter_kind: The kind of parameters to keep.
|
|
318
361
|
"""
|
|
319
362
|
|
|
320
363
|
keys_to_drop = [
|
|
321
|
-
key
|
|
364
|
+
key
|
|
365
|
+
for key in self.inference_data.posterior.keys()
|
|
366
|
+
if not key.startswith(parameter_kind)
|
|
322
367
|
]
|
|
323
368
|
|
|
324
369
|
reduced_id = az.extract(
|
|
@@ -403,6 +448,7 @@ class FitResult:
|
|
|
403
448
|
title: str | None = None,
|
|
404
449
|
figsize: tuple[float, float] = (6, 6),
|
|
405
450
|
x_lims: tuple[float, float] | None = None,
|
|
451
|
+
rescale_background: bool = False,
|
|
406
452
|
) -> list[plt.Figure]:
|
|
407
453
|
r"""
|
|
408
454
|
Plot the posterior predictive distribution of the model. It also features a residual plot, defined using the
|
|
@@ -423,6 +469,7 @@ class FitResult:
|
|
|
423
469
|
title: The title of the plot.
|
|
424
470
|
figsize: The size of the figure.
|
|
425
471
|
x_lims: The limits of the x-axis.
|
|
472
|
+
rescale_background: Whether to rescale the background model to the data with backscal ratio.
|
|
426
473
|
|
|
427
474
|
Returns:
|
|
428
475
|
A list of matplotlib figures for each observation in the model.
|
|
@@ -573,10 +620,14 @@ class FitResult:
|
|
|
573
620
|
)
|
|
574
621
|
)
|
|
575
622
|
|
|
623
|
+
rescale_background_factor = (
|
|
624
|
+
obsconf.folded_backratio.data if rescale_background else 1.0
|
|
625
|
+
)
|
|
626
|
+
|
|
576
627
|
model_bkg_plot = _plot_binned_samples_with_error(
|
|
577
628
|
ax[0],
|
|
578
629
|
xbins.value,
|
|
579
|
-
y_samples_bkg.value,
|
|
630
|
+
y_samples_bkg.value * rescale_background_factor,
|
|
580
631
|
color=BACKGROUND_COLOR,
|
|
581
632
|
alpha_envelope=alpha_envelope,
|
|
582
633
|
n_sigmas=n_sigmas,
|
|
@@ -585,9 +636,9 @@ class FitResult:
|
|
|
585
636
|
true_bkg_plot = _plot_poisson_data_with_error(
|
|
586
637
|
ax[0],
|
|
587
638
|
xbins.value,
|
|
588
|
-
y_observed_bkg.value,
|
|
589
|
-
y_observed_bkg_low.value,
|
|
590
|
-
y_observed_bkg_high.value,
|
|
639
|
+
y_observed_bkg.value * rescale_background_factor,
|
|
640
|
+
y_observed_bkg_low.value * rescale_background_factor,
|
|
641
|
+
y_observed_bkg_high.value * rescale_background_factor,
|
|
591
642
|
color=BACKGROUND_DATA_COLOR,
|
|
592
643
|
alpha=0.7,
|
|
593
644
|
)
|
jaxspec/data/__init__.py
CHANGED
|
@@ -6,5 +6,7 @@ from .observation import Observation
|
|
|
6
6
|
|
|
7
7
|
u.add_enabled_aliases({"counts": u.count})
|
|
8
8
|
u.add_enabled_aliases({"channel": u.dimensionless_unscaled})
|
|
9
|
+
u.add_enabled_aliases({"ADU": u.dimensionless_unscaled}) # Appears in SIXTE outputs
|
|
10
|
+
|
|
9
11
|
# Arbitrary units are found in .rsp files , let's hope it is compatible with what we would expect as the rmf x arf
|
|
10
12
|
# u.add_enabled_aliases({"au": u.dimensionless_unscaled})
|
jaxspec/data/instrument.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import os
|
|
2
|
+
import sparse
|
|
2
3
|
|
|
3
4
|
import numpy as np
|
|
4
5
|
import xarray as xr
|
|
@@ -92,7 +93,7 @@ class Instrument(xr.Dataset):
|
|
|
92
93
|
|
|
93
94
|
else:
|
|
94
95
|
specresp = rmf.matrix.sum(axis=0)
|
|
95
|
-
rmf.sparse_matrix
|
|
96
|
+
rmf.sparse_matrix = sparse.COO( rmf.matrix / specresp )
|
|
96
97
|
|
|
97
98
|
return cls.from_matrix(
|
|
98
99
|
rmf.sparse_matrix, specresp, rmf.energ_lo, rmf.energ_hi, rmf.e_min, rmf.e_max
|
jaxspec/data/util.py
CHANGED
|
@@ -152,11 +152,11 @@ def forward_model_with_multiple_inputs(
|
|
|
152
152
|
transfer_matrix = BCOO.from_scipy_sparse(
|
|
153
153
|
obs_configuration.transfer_matrix.data.to_scipy_sparse().tocsr()
|
|
154
154
|
)
|
|
155
|
+
expected_counts = transfer_matrix @ flux_func(parameters).T
|
|
155
156
|
|
|
156
157
|
else:
|
|
157
158
|
transfer_matrix = np.asarray(obs_configuration.transfer_matrix.data.todense())
|
|
158
|
-
|
|
159
|
-
expected_counts = jnp.matvec(transfer_matrix, flux_func(parameters))
|
|
159
|
+
expected_counts = jnp.matvec(transfer_matrix, flux_func(parameters))
|
|
160
160
|
|
|
161
161
|
# The result is clipped at 1e-6 to avoid 0 round-off and diverging likelihoods
|
|
162
162
|
return jnp.clip(expected_counts, a_min=1e-6)
|
jaxspec/fit/_fitter.py
CHANGED
|
@@ -9,11 +9,13 @@ import matplotlib.pyplot as plt
|
|
|
9
9
|
import numpyro
|
|
10
10
|
|
|
11
11
|
from jax import random
|
|
12
|
+
from jax.numpy import concatenate
|
|
12
13
|
from jax.random import PRNGKey
|
|
13
14
|
from numpyro.infer import AIES, ESS, MCMC, NUTS, SVI, Predictive, Trace_ELBO
|
|
14
15
|
from numpyro.infer.autoguide import AutoMultivariateNormal
|
|
15
16
|
|
|
16
17
|
from ..analysis.results import FitResult
|
|
18
|
+
from ..model.background import SubtractedBackground
|
|
17
19
|
from ._bayesian_model import BayesianModel
|
|
18
20
|
|
|
19
21
|
|
|
@@ -52,9 +54,22 @@ class BayesianModelFitter(BayesianModel, ABC):
|
|
|
52
54
|
)
|
|
53
55
|
|
|
54
56
|
log_likelihood = numpyro.infer.log_likelihood(numpyro_model, posterior_samples)
|
|
57
|
+
if len(log_likelihood.keys()) > 1:
|
|
58
|
+
log_likelihood["full"] = concatenate([ll for _, ll in log_likelihood.items()], axis=1)
|
|
59
|
+
log_likelihood["obs/~/all"] = concatenate(
|
|
60
|
+
[ll for k, ll in log_likelihood.items() if "obs" in k], axis=1
|
|
61
|
+
)
|
|
62
|
+
|
|
63
|
+
# Subtracted background is not fitted so there is no likelihood
|
|
64
|
+
if self.background_model is not None and not isinstance(
|
|
65
|
+
self.background_model, SubtractedBackground
|
|
66
|
+
):
|
|
67
|
+
log_likelihood["bkg/~/all"] = concatenate(
|
|
68
|
+
[ll for k, ll in log_likelihood.items() if "bkg" in k], axis=1
|
|
69
|
+
)
|
|
55
70
|
|
|
56
71
|
seeded_model = numpyro.handlers.substitute(
|
|
57
|
-
numpyro.handlers.seed(numpyro_model, keys[
|
|
72
|
+
numpyro.handlers.seed(numpyro_model, keys[2]),
|
|
58
73
|
substitute_fn=numpyro.infer.init_to_sample,
|
|
59
74
|
)
|
|
60
75
|
|
|
@@ -108,12 +123,10 @@ class BayesianModelFitter(BayesianModel, ABC):
|
|
|
108
123
|
predictive_parameters = []
|
|
109
124
|
|
|
110
125
|
for key, value in self._observation_container.items():
|
|
126
|
+
predictive_parameters.append(f"obs/~/{key}")
|
|
111
127
|
if self.background_model is not None:
|
|
112
|
-
predictive_parameters.append(f"obs/~/{key}")
|
|
113
128
|
predictive_parameters.append(f"bkg/~/{key}")
|
|
114
129
|
# predictive_parameters.append(f"ins/~/{key}")
|
|
115
|
-
else:
|
|
116
|
-
predictive_parameters.append(f"obs/~/{key}")
|
|
117
130
|
# predictive_parameters.append(f"ins/~/{key}")
|
|
118
131
|
|
|
119
132
|
inference_data.posterior_predictive = inference_data.posterior_predictive[
|
|
@@ -247,7 +260,7 @@ class VIFitter(BayesianModelFitter):
|
|
|
247
260
|
|
|
248
261
|
svi = SVI(bayesian_model, guide, optimizer, loss=loss)
|
|
249
262
|
|
|
250
|
-
keys = random.split(random.PRNGKey(rng_key),
|
|
263
|
+
keys = random.split(random.PRNGKey(rng_key), 2)
|
|
251
264
|
svi_result = svi.run(keys[0], num_steps)
|
|
252
265
|
params = svi_result.params
|
|
253
266
|
|
jaxspec/model/abc.py
CHANGED
|
@@ -372,9 +372,10 @@ class AdditiveComponent(ModelComponent):
|
|
|
372
372
|
continuum = self.continuum(energy)
|
|
373
373
|
integrated_continuum = self.integrated_continuum(e_low, e_high)
|
|
374
374
|
|
|
375
|
-
return
|
|
376
|
-
continuum * energy**2, jnp.log(energy), axis=-1
|
|
377
|
-
|
|
375
|
+
return (
|
|
376
|
+
jsp.integrate.trapezoid(continuum * energy**2, jnp.log(energy), axis=-1)
|
|
377
|
+
+ integrated_continuum * (e_high + e_low) / 2.0
|
|
378
|
+
)
|
|
378
379
|
|
|
379
380
|
@partial(jax.jit, static_argnums=0, static_argnames="n_points")
|
|
380
381
|
def photon_flux(self, params, e_low, e_high, n_points=2):
|
jaxspec/model/additive.py
CHANGED
|
@@ -34,8 +34,10 @@ class Powerlaw(AdditiveComponent):
|
|
|
34
34
|
self.alpha = nnx.Param(1.7)
|
|
35
35
|
self.norm = nnx.Param(1e-4)
|
|
36
36
|
|
|
37
|
-
def
|
|
38
|
-
return
|
|
37
|
+
def integrated_continuum(self, e_low, e_high):
|
|
38
|
+
return (
|
|
39
|
+
self.norm / (1 - self.alpha) * (e_high ** (1 - self.alpha) - e_low ** (1 - self.alpha))
|
|
40
|
+
)
|
|
39
41
|
|
|
40
42
|
|
|
41
43
|
class Additiveconstant(AdditiveComponent):
|
|
@@ -166,28 +168,22 @@ class Gauss(AdditiveComponent):
|
|
|
166
168
|
self.sigma = nnx.Param(1e-2)
|
|
167
169
|
self.norm = nnx.Param(1.0)
|
|
168
170
|
|
|
169
|
-
def
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
171
|
+
def integrated_continuum(self, e_low, e_high):
|
|
172
|
+
upper = jsp.stats.norm.cdf(
|
|
173
|
+
e_high,
|
|
174
|
+
loc=jnp.asarray(self.El),
|
|
175
|
+
scale=jnp.asarray(self.sigma),
|
|
174
176
|
)
|
|
175
177
|
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
e_high,
|
|
181
|
-
loc=jnp.asarray(self.El),
|
|
182
|
-
scale=jnp.asarray(self.sigma),
|
|
183
|
-
)
|
|
184
|
-
- jsp.stats.norm.cdf(
|
|
185
|
-
e_low,
|
|
186
|
-
loc=jnp.asarray(self.El),
|
|
187
|
-
scale=jnp.asarray(self.sigma),
|
|
188
|
-
) #/ (1 - jsp.special.erf(- self.El / (self.sigma * jnp.sqrt(2))))
|
|
178
|
+
lower = jsp.stats.norm.cdf(
|
|
179
|
+
e_low,
|
|
180
|
+
loc=jnp.asarray(self.El),
|
|
181
|
+
scale=jnp.asarray(self.sigma),
|
|
189
182
|
)
|
|
190
|
-
|
|
183
|
+
|
|
184
|
+
factor = 2 / (1 - jsp.special.erf(-self.El / (self.sigma * jnp.sqrt(2))))
|
|
185
|
+
|
|
186
|
+
return self.norm * (upper - lower) * factor
|
|
191
187
|
|
|
192
188
|
|
|
193
189
|
class Cutoffpl(AdditiveComponent):
|
jaxspec/model/multiplicative.py
CHANGED
|
@@ -49,7 +49,9 @@ class Expfac(MultiplicativeComponent):
|
|
|
49
49
|
self.E_c = nnx.Param(1.0)
|
|
50
50
|
|
|
51
51
|
def factor(self, energy):
|
|
52
|
-
return jnp.where(
|
|
52
|
+
return jnp.where(
|
|
53
|
+
energy >= self.E_c, 1.0 + self.A * jnp.exp(-self.f * energy), 1.0
|
|
54
|
+
)
|
|
53
55
|
|
|
54
56
|
|
|
55
57
|
class Tbabs(MultiplicativeComponent):
|
|
@@ -91,6 +93,49 @@ class Tbabs(MultiplicativeComponent):
|
|
|
91
93
|
return jnp.exp(-self.nh * sigma)
|
|
92
94
|
|
|
93
95
|
|
|
96
|
+
class zTbabs(MultiplicativeComponent):
|
|
97
|
+
r"""
|
|
98
|
+
The redshifted Tuebingen-Boulder ISM absorption model. See `Tbabs` for more details.
|
|
99
|
+
From Xspec manual:
|
|
100
|
+
This model assumes that 20% of the hydrogen is molecular
|
|
101
|
+
and that there is NO MATERIAL IN GRAINS.
|
|
102
|
+
|
|
103
|
+
$$
|
|
104
|
+
\mathcal{M}(E) = \exp^{-N_{\text{H}}\sigma(E)}
|
|
105
|
+
$$
|
|
106
|
+
|
|
107
|
+
!!! abstract "Parameters"
|
|
108
|
+
* $N_{\text{H}}$ (`nh`) $\left[\frac{\text{atoms}~10^{22}}{\text{cm}^2}\right]$ : Equivalent hydrogen column density
|
|
109
|
+
* $z$ (`z`) $\left[\text{dimensionless}\right]$ : Redshift
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
!!! note
|
|
113
|
+
Abundances and cross-sections $\sigma$ can be found in Wilms et al. (2000).
|
|
114
|
+
|
|
115
|
+
"""
|
|
116
|
+
|
|
117
|
+
def __init__(self):
|
|
118
|
+
table = Table.read(table_manager.fetch("xsect_tbabs_wilm.fits"))
|
|
119
|
+
self._energy = np.asarray(table["ENERGY"], dtype=np.float64)
|
|
120
|
+
self._sigma = np.asarray(table["SIGMA"], dtype=np.float64)
|
|
121
|
+
self.nh = nnx.Param(1.0)
|
|
122
|
+
self.z = nnx.Param(1.0)
|
|
123
|
+
|
|
124
|
+
def factor(self, energy):
|
|
125
|
+
z = jnp.asarray(self.z)
|
|
126
|
+
sigma = jnp.exp(
|
|
127
|
+
jnp.interp(
|
|
128
|
+
jnp.log(energy) + jnp.log1p(z),
|
|
129
|
+
jnp.log(self._energy),
|
|
130
|
+
jnp.log(self._sigma),
|
|
131
|
+
left="extrapolate",
|
|
132
|
+
right="extrapolate",
|
|
133
|
+
)
|
|
134
|
+
)
|
|
135
|
+
|
|
136
|
+
return jnp.exp(-self.nh * sigma)
|
|
137
|
+
|
|
138
|
+
|
|
94
139
|
class Phabs(MultiplicativeComponent):
|
|
95
140
|
r"""
|
|
96
141
|
A photoelectric absorption model.
|
|
@@ -215,7 +260,9 @@ class Zedge(MultiplicativeComponent):
|
|
|
215
260
|
|
|
216
261
|
def factor(self, energy):
|
|
217
262
|
return jnp.where(
|
|
218
|
-
energy <= self.Ec,
|
|
263
|
+
energy <= self.Ec,
|
|
264
|
+
1.0,
|
|
265
|
+
jnp.exp(-self.D * (energy * (1 + self.z) / self.Ec) ** 3),
|
|
219
266
|
)
|
|
220
267
|
|
|
221
268
|
|
|
@@ -246,7 +293,11 @@ class Tbpcf(MultiplicativeComponent):
|
|
|
246
293
|
def factor(self, energy):
|
|
247
294
|
sigma = jnp.exp(
|
|
248
295
|
jnp.interp(
|
|
249
|
-
energy,
|
|
296
|
+
energy,
|
|
297
|
+
self._energy,
|
|
298
|
+
jnp.log(self._sigma),
|
|
299
|
+
left="extrapolate",
|
|
300
|
+
right="extrapolate",
|
|
250
301
|
)
|
|
251
302
|
)
|
|
252
303
|
|
jaxspec/util/online_storage.py
CHANGED
|
@@ -25,4 +25,5 @@ table_manager = pooch.create(
|
|
|
25
25
|
"example_data/NGC7793_ULX4/MOS2.arf": "sha256:a126ff5a95a5f4bb93ed846944cf411d6e1c448626cb73d347e33324663d8b3f",
|
|
26
26
|
"example_data/NGC7793_ULX4/PNbackground_spectrum.fits": "sha256:55e017e0c19b324245fef049dff2a7a2e49b9a391667ca9c4f667c4f683b1f49",
|
|
27
27
|
},
|
|
28
|
+
retry_if_failed=10,
|
|
28
29
|
)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: jaxspec
|
|
3
|
-
Version: 0.3.
|
|
3
|
+
Version: 0.3.4
|
|
4
4
|
Summary: jaxspec is a bayesian spectral fitting library for X-ray astronomy.
|
|
5
5
|
Project-URL: Homepage, https://github.com/renecotyfanboy/jaxspec
|
|
6
6
|
Project-URL: Documentation, https://jaxspec.readthedocs.io/en/latest/
|
|
@@ -8,7 +8,7 @@ Author-email: sdupourque <sdupourque@irap.omp.eu>
|
|
|
8
8
|
License-Expression: MIT
|
|
9
9
|
License-File: LICENSE.md
|
|
10
10
|
Requires-Python: <3.13,>=3.10
|
|
11
|
-
Requires-Dist: arviz<0.
|
|
11
|
+
Requires-Dist: arviz<0.24.0,>=0.17.1
|
|
12
12
|
Requires-Dist: astropy<8,>=6.0.0
|
|
13
13
|
Requires-Dist: catppuccin<3,>=2.3.4
|
|
14
14
|
Requires-Dist: chainconsumer<2,>=1.1.2
|
|
@@ -16,19 +16,15 @@ Requires-Dist: cmasher<2,>=1.6.3
|
|
|
16
16
|
Requires-Dist: flax>0.10.5
|
|
17
17
|
Requires-Dist: interpax<0.4,>=0.3.5
|
|
18
18
|
Requires-Dist: jax<0.7,>=0.5.0
|
|
19
|
-
Requires-Dist: jaxns<3,>=2.6.7
|
|
20
|
-
Requires-Dist: jaxopt<0.9,>=0.8.3
|
|
21
19
|
Requires-Dist: matplotlib<4,>=3.8.0
|
|
22
20
|
Requires-Dist: mendeleev<1.2,>=0.15
|
|
23
21
|
Requires-Dist: networkx~=3.1
|
|
24
22
|
Requires-Dist: numpy<3.0.0
|
|
25
|
-
Requires-Dist: numpyro<0.
|
|
26
|
-
Requires-Dist: optimistix<0.0.12,>=0.0.10
|
|
23
|
+
Requires-Dist: numpyro<0.21,>=0.17.0
|
|
27
24
|
Requires-Dist: pandas<3,>=2.2.0
|
|
28
25
|
Requires-Dist: pooch<2,>=1.8.2
|
|
29
26
|
Requires-Dist: scipy<1.16
|
|
30
|
-
Requires-Dist: seaborn
|
|
31
|
-
Requires-Dist: simpleeval<1.1.0,>=0.9.13
|
|
27
|
+
Requires-Dist: seaborn>=0.13.2
|
|
32
28
|
Requires-Dist: sparse>0.15
|
|
33
29
|
Requires-Dist: tinygp<0.4,>=0.3.0
|
|
34
30
|
Requires-Dist: watermark<3,>=2.4.3
|
|
@@ -2,13 +2,13 @@ jaxspec/__init__.py,sha256=Sbn02lX6Y-zNXk17N8dec22c5jeypiS0LkHmGfz7lWA,126
|
|
|
2
2
|
jaxspec/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
3
3
|
jaxspec/analysis/_plot.py,sha256=0xEz-e_xk7XvU6PUfbNwxaWg1-SxAF2XAqhkxWEhIFs,6239
|
|
4
4
|
jaxspec/analysis/compare.py,sha256=g2UFhmR9Zt-7cz5gQFOB6lXuklXB3yTyUvjTypOzoSY,725
|
|
5
|
-
jaxspec/analysis/results.py,sha256=
|
|
6
|
-
jaxspec/data/__init__.py,sha256=
|
|
7
|
-
jaxspec/data/instrument.py,sha256=
|
|
5
|
+
jaxspec/analysis/results.py,sha256=nZ7JORgA6YYei8hRHmGvhUmnSpd11FfryP-E7UVVT9s,28650
|
|
6
|
+
jaxspec/data/__init__.py,sha256=9fZRyB3eXdEi_ZsTcHT64xH3kW3jQFD0XS5eIAD0RDo,501
|
|
7
|
+
jaxspec/data/instrument.py,sha256=weiPcEll1jZM6lqhxpF1aPIRwvaP6bygSB8jLBABXto,4815
|
|
8
8
|
jaxspec/data/obsconf.py,sha256=bkYuD6mJgj8QmRaDVhcnXwUukVdo20xllzaI57prHag,10056
|
|
9
9
|
jaxspec/data/observation.py,sha256=7FHJm1jHEEFyrqxg3COsGmfdh5dg-5XnfKCp1yb5fNY,7411
|
|
10
10
|
jaxspec/data/ogip.py,sha256=eMmBuROW4eMRxRHkPPyGHf933e0IcREqB8WMQFMS2lY,9810
|
|
11
|
-
jaxspec/data/util.py,sha256=
|
|
11
|
+
jaxspec/data/util.py,sha256=2JWoHsKJqGXUn74zPeoAqdU86x2n8NyfZGvpqC21ZaY,9832
|
|
12
12
|
jaxspec/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
13
13
|
jaxspec/experimental/interpolator.py,sha256=mJRdCB4B71le3dQL_S_E6Wkqpb6QLT7Wdzlok-rU6Ok,2652
|
|
14
14
|
jaxspec/experimental/interpolator_jax.py,sha256=13lflsjbImDRZTObSRDtZnujrXBvEP367Rn20eByONs,2967
|
|
@@ -18,25 +18,25 @@ jaxspec/experimental/tabulated.py,sha256=H0llUiso2KGH4xUzTUSVPy-6I8D3wm707lU_Z1P
|
|
|
18
18
|
jaxspec/fit/__init__.py,sha256=OaS0-Hkb3Hd-AkE2o-KWfoWMX0NSCPY-_FP2znHf9l0,153
|
|
19
19
|
jaxspec/fit/_bayesian_model.py,sha256=7c2Twgz06QV1S9DdctdVk5YT1v7P-ln100bWXAvv7Go,15179
|
|
20
20
|
jaxspec/fit/_build_model.py,sha256=pNZVuVfwOq3Pg23opH7xRv28DsSkQZpvy2Z-1hQSfNs,3219
|
|
21
|
-
jaxspec/fit/_fitter.py,sha256=
|
|
21
|
+
jaxspec/fit/_fitter.py,sha256=lkLFvwqE6NGpBk4i8gdTHGD8QQahCp9OgJS96j7N6FA,10383
|
|
22
22
|
jaxspec/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
23
23
|
jaxspec/model/_graph_util.py,sha256=hPvHYmAxb7P3nyIecaZ7RqWOjwcZ1WvUByt_yNANiaY,4552
|
|
24
|
-
jaxspec/model/abc.py,sha256=
|
|
25
|
-
jaxspec/model/additive.py,sha256=
|
|
24
|
+
jaxspec/model/abc.py,sha256=vvHM4teepc8VLqbpAtqf1b55oF00R_Lo_6nrBO5KmmQ,14793
|
|
25
|
+
jaxspec/model/additive.py,sha256=MGHqJ4Ai2kPkMrZOLom4vTNyeg2oczpg4A9Rv5JdU34,20366
|
|
26
26
|
jaxspec/model/background.py,sha256=VLSrU0YCW9GSHCtaEdcth-sp74aPyEVSizIMFkTpM7M,7759
|
|
27
27
|
jaxspec/model/instrument.py,sha256=1zLZgHmBZs8RLKTMT3Wu4bCx6JnxBUjhRIpYG2rLaZM,2947
|
|
28
28
|
jaxspec/model/list.py,sha256=uC9rLEEeph10q6shat86WLACVuTSx73RGMl8Ij0jqQY,875
|
|
29
|
-
jaxspec/model/multiplicative.py,sha256=
|
|
29
|
+
jaxspec/model/multiplicative.py,sha256=L4zyCrYU364Cwb2bDJc9ydSc_EMzNQ21zNWIM8EbLKE,9793
|
|
30
30
|
jaxspec/scripts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
31
31
|
jaxspec/scripts/debug.py,sha256=qhyDtX4G5UdChmTLCM-5Wti4XZU-sU5S-wDb6TZjrvM,292
|
|
32
32
|
jaxspec/util/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
33
33
|
jaxspec/util/abundance.py,sha256=fsC313taIlGzQsZNwbYsJupDWm7ZbqzGhY66Ku394Mw,8546
|
|
34
34
|
jaxspec/util/integrate.py,sha256=7GwBSagmDzsF3P53tPs-oakeq0zHEwmZZS2zQlXngbE,4634
|
|
35
35
|
jaxspec/util/misc.py,sha256=O3qorCL1Y2X1BS2jdd36C1eDHK9QDXTSOr9kj3uqcJo,654
|
|
36
|
-
jaxspec/util/online_storage.py,sha256=
|
|
36
|
+
jaxspec/util/online_storage.py,sha256=wwpowxmDgAqKzeUwmGUIxttA4VKUoR270Ew-F_0DrkE,2493
|
|
37
37
|
jaxspec/util/typing.py,sha256=ZQM_l68qyYnIBZPz_1mKvwPMx64jvVBD8Uj6bx9sHv0,140
|
|
38
|
-
jaxspec-0.3.
|
|
39
|
-
jaxspec-0.3.
|
|
40
|
-
jaxspec-0.3.
|
|
41
|
-
jaxspec-0.3.
|
|
42
|
-
jaxspec-0.3.
|
|
38
|
+
jaxspec-0.3.4.dist-info/METADATA,sha256=VzNrPKy8_5ReZk5REcwADXU-m8QpiICd-Gx9ow_1X2w,4045
|
|
39
|
+
jaxspec-0.3.4.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
40
|
+
jaxspec-0.3.4.dist-info/entry_points.txt,sha256=4ffU5AImfcEBxgWTqopQll2YffpFldOswXRh16pd0Dc,72
|
|
41
|
+
jaxspec-0.3.4.dist-info/licenses/LICENSE.md,sha256=2q5XoWzddts5IqzIcgYYMOL21puU3MfO8gvT3Ype1eQ,1073
|
|
42
|
+
jaxspec-0.3.4.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|