jaxspec 0.2.0__tar.gz → 0.2.1.dev2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/PKG-INFO +3 -3
- {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/pyproject.toml +4 -4
- {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/analysis/results.py +14 -4
- {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/data/util.py +68 -8
- {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/fit.py +22 -2
- {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/model/additive.py +13 -13
- {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/model/multiplicative.py +3 -3
- {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/LICENSE.md +0 -0
- {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/README.md +0 -0
- {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/__init__.py +0 -0
- {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/_fit/__init__.py +0 -0
- {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/_fit/_build_model.py +0 -0
- {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/analysis/__init__.py +0 -0
- {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/analysis/_plot.py +0 -0
- {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/analysis/compare.py +0 -0
- {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/data/__init__.py +0 -0
- {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/data/instrument.py +0 -0
- {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/data/obsconf.py +0 -0
- {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/data/observation.py +0 -0
- {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/data/ogip.py +0 -0
- {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/model/__init__.py +0 -0
- {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/model/_graph_util.py +0 -0
- {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/model/abc.py +0 -0
- {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/model/background.py +0 -0
- {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/model/list.py +0 -0
- {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/scripts/__init__.py +0 -0
- {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/scripts/debug.py +0 -0
- {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/util/__init__.py +0 -0
- {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/util/abundance.py +0 -0
- {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/util/integrate.py +0 -0
- {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/util/misc.py +0 -0
- {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/util/online_storage.py +0 -0
- {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/util/typing.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.3
|
|
2
2
|
Name: jaxspec
|
|
3
|
-
Version: 0.2.0
|
|
3
|
+
Version: 0.2.1.dev2
|
|
4
4
|
Summary: jaxspec is a bayesian spectral fitting library for X-ray astronomy.
|
|
5
5
|
License: MIT
|
|
6
6
|
Author: sdupourque
|
|
@@ -18,7 +18,7 @@ Requires-Dist: chainconsumer (>=1.1.2,<2.0.0)
|
|
|
18
18
|
Requires-Dist: cmasher (>=1.6.3,<2.0.0)
|
|
19
19
|
Requires-Dist: flax (>=0.10.1,<0.11.0)
|
|
20
20
|
Requires-Dist: interpax (>=0.3.3,<0.4.0)
|
|
21
|
-
Requires-Dist: jax (>=0.
|
|
21
|
+
Requires-Dist: jax (>=0.5.0,<0.6.0)
|
|
22
22
|
Requires-Dist: jaxns (>=2.6.7,<3.0.0)
|
|
23
23
|
Requires-Dist: jaxopt (>=0.8.1,<0.9.0)
|
|
24
24
|
Requires-Dist: matplotlib (>=3.8.0,<4.0.0)
|
|
@@ -26,7 +26,7 @@ Requires-Dist: mendeleev (>=0.15,<0.20)
|
|
|
26
26
|
Requires-Dist: networkx (>=3.1,<4.0)
|
|
27
27
|
Requires-Dist: numpy (<2.0.0)
|
|
28
28
|
Requires-Dist: numpyro (>=0.16.1,<0.17.0)
|
|
29
|
-
Requires-Dist: optimistix (>=0.0.7,<0.0.
|
|
29
|
+
Requires-Dist: optimistix (>=0.0.7,<0.0.11)
|
|
30
30
|
Requires-Dist: pandas (>=2.2.0,<3.0.0)
|
|
31
31
|
Requires-Dist: pooch (>=1.8.2,<2.0.0)
|
|
32
32
|
Requires-Dist: scipy (<1.15)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[tool.poetry]
|
|
2
2
|
name = "jaxspec"
|
|
3
|
-
version = "0.2.0"
|
|
3
|
+
version = "0.2.1dev-2"
|
|
4
4
|
description = "jaxspec is a bayesian spectral fitting library for X-ray astronomy."
|
|
5
5
|
authors = ["sdupourque <sdupourque@irap.omp.eu>"]
|
|
6
6
|
license = "MIT"
|
|
@@ -11,7 +11,7 @@ documentation = "https://jaxspec.readthedocs.io/en/latest/"
|
|
|
11
11
|
|
|
12
12
|
[tool.poetry.dependencies]
|
|
13
13
|
python = ">=3.10,<3.13"
|
|
14
|
-
jax = "^0.
|
|
14
|
+
jax = "^0.5.0"
|
|
15
15
|
numpy = "<2.0.0"
|
|
16
16
|
pandas = "^2.2.0"
|
|
17
17
|
astropy = "^6.0.0"
|
|
@@ -26,7 +26,7 @@ jaxopt = "^0.8.1"
|
|
|
26
26
|
tinygp = "^0.3.0"
|
|
27
27
|
seaborn = "^0.13.1"
|
|
28
28
|
sparse = "^0.15.4"
|
|
29
|
-
optimistix = ">=0.0.7,<0.0.
|
|
29
|
+
optimistix = ">=0.0.7,<0.0.11"
|
|
30
30
|
scipy = "<1.15"
|
|
31
31
|
mendeleev = ">=0.15,<0.20"
|
|
32
32
|
jaxns = "^2.6.7"
|
|
@@ -56,7 +56,7 @@ testbook = "^0.4.2"
|
|
|
56
56
|
|
|
57
57
|
[tool.poetry.group.dev.dependencies]
|
|
58
58
|
pre-commit = ">=3.5,<5.0"
|
|
59
|
-
ruff = ">=0.2.1,<0.
|
|
59
|
+
ruff = ">=0.2.1,<0.10.0"
|
|
60
60
|
jupyterlab = "^4.0.7"
|
|
61
61
|
notebook = "^7.0.6"
|
|
62
62
|
ipywidgets = "^8.1.1"
|
|
@@ -391,6 +391,8 @@ class FitResult:
|
|
|
391
391
|
alpha_envelope: (float, float) = (0.15, 0.25),
|
|
392
392
|
style: str | Any = "default",
|
|
393
393
|
title: str | None = None,
|
|
394
|
+
figsize: tuple[float, float] = (6, 6),
|
|
395
|
+
x_lims: tuple[float, float] | None = None,
|
|
394
396
|
) -> list[plt.Figure]:
|
|
395
397
|
r"""
|
|
396
398
|
Plot the posterior predictive distribution of the model. It also features a residual plot, defined using the
|
|
@@ -400,7 +402,7 @@ class FitResult:
|
|
|
400
402
|
{(\text{Posterior counts})_{84\%}-(\text{Posterior counts})_{16\%}} $$
|
|
401
403
|
|
|
402
404
|
Parameters:
|
|
403
|
-
|
|
405
|
+
n_sigmas: The number of sigmas to plot the envelops.
|
|
404
406
|
x_unit: The units of the x-axis. It can be either a string (parsable by astropy.units) or an astropy unit. It must be homogeneous to either a length, a frequency or an energy.
|
|
405
407
|
y_type: The type of the y-axis. It can be either "counts", "countrate", "photon_flux" or "photon_flux_density".
|
|
406
408
|
plot_background: Whether to plot the background model if it is included in the fit.
|
|
@@ -408,6 +410,9 @@ class FitResult:
|
|
|
408
410
|
scale: The axes scaling
|
|
409
411
|
alpha_envelope: The transparency range for envelops
|
|
410
412
|
style: The style of the plot. It can be either a string or a matplotlib style context.
|
|
413
|
+
title: The title of the plot.
|
|
414
|
+
figsize: The size of the figure.
|
|
415
|
+
x_lims: The limits of the x-axis.
|
|
411
416
|
|
|
412
417
|
Returns:
|
|
413
418
|
A list of matplotlib figures for each observation in the model.
|
|
@@ -436,7 +441,7 @@ class FitResult:
|
|
|
436
441
|
fig, ax = plt.subplots(
|
|
437
442
|
2,
|
|
438
443
|
1,
|
|
439
|
-
figsize=
|
|
444
|
+
figsize=figsize,
|
|
440
445
|
sharex="col",
|
|
441
446
|
height_ratios=[0.7, 0.3],
|
|
442
447
|
)
|
|
@@ -525,8 +530,10 @@ class FitResult:
|
|
|
525
530
|
alpha_envelope=alpha_envelope,
|
|
526
531
|
)
|
|
527
532
|
|
|
533
|
+
name = component_name.split("*")[-1]
|
|
534
|
+
|
|
528
535
|
legend_plots += component_plot
|
|
529
|
-
legend_labels.append(
|
|
536
|
+
legend_labels.append(name)
|
|
530
537
|
|
|
531
538
|
if self.background_model is not None and plot_background:
|
|
532
539
|
# We plot the background only if it is included in the fit, i.e. by subtracting
|
|
@@ -617,6 +624,9 @@ class FitResult:
|
|
|
617
624
|
ax[0].set_xscale("log")
|
|
618
625
|
ax[0].set_yscale("log")
|
|
619
626
|
|
|
627
|
+
if x_lims is not None:
|
|
628
|
+
ax[0].set_xlim(*x_lims)
|
|
629
|
+
|
|
620
630
|
fig.align_ylabels()
|
|
621
631
|
plt.subplots_adjust(hspace=0.0)
|
|
622
632
|
fig.tight_layout()
|
|
@@ -654,7 +664,7 @@ class FitResult:
|
|
|
654
664
|
"""
|
|
655
665
|
|
|
656
666
|
consumer = ChainConsumer()
|
|
657
|
-
consumer.add_chain(self.to_chain(
|
|
667
|
+
consumer.add_chain(self.to_chain("Results"))
|
|
658
668
|
consumer.set_plot_config(config)
|
|
659
669
|
|
|
660
670
|
# Context for default mpl style
|
|
@@ -1,14 +1,16 @@
|
|
|
1
1
|
from collections.abc import Mapping
|
|
2
2
|
from pathlib import Path
|
|
3
|
-
from typing import Literal, TypeVar
|
|
3
|
+
from typing import TYPE_CHECKING, Literal, TypeVar
|
|
4
4
|
|
|
5
5
|
import jax
|
|
6
|
+
import jax.numpy as jnp
|
|
7
|
+
import numpy as np
|
|
6
8
|
import numpyro
|
|
7
9
|
|
|
8
10
|
from astropy.io import fits
|
|
11
|
+
from jax.experimental.sparse import BCOO
|
|
9
12
|
from numpyro import handlers
|
|
10
13
|
|
|
11
|
-
from .._fit._build_model import forward_model
|
|
12
14
|
from ..model.abc import SpectralModel
|
|
13
15
|
from ..util.online_storage import table_manager
|
|
14
16
|
from . import Instrument, ObsConfiguration, Observation
|
|
@@ -16,6 +18,10 @@ from . import Instrument, ObsConfiguration, Observation
|
|
|
16
18
|
K = TypeVar("K")
|
|
17
19
|
V = TypeVar("V")
|
|
18
20
|
|
|
21
|
+
if TYPE_CHECKING:
|
|
22
|
+
from ..data import ObsConfiguration
|
|
23
|
+
from ..model.abc import SpectralModel
|
|
24
|
+
|
|
19
25
|
|
|
20
26
|
def load_example_pha(
|
|
21
27
|
source: Literal["NGC7793_ULX4_PN", "NGC7793_ULX4_ALL"],
|
|
@@ -124,8 +130,40 @@ def load_example_obsconf(source: Literal["NGC7793_ULX4_PN", "NGC7793_ULX4_ALL"])
|
|
|
124
130
|
raise ValueError(f"{source} not recognized.")
|
|
125
131
|
|
|
126
132
|
|
|
133
|
+
def forward_model_with_multiple_inputs(
|
|
134
|
+
model: "SpectralModel",
|
|
135
|
+
parameters,
|
|
136
|
+
obs_configuration: "ObsConfiguration",
|
|
137
|
+
sparse=False,
|
|
138
|
+
):
|
|
139
|
+
energies = np.asarray(obs_configuration.in_energies)
|
|
140
|
+
parameter_dims = next(iter(parameters.values())).shape
|
|
141
|
+
|
|
142
|
+
def flux_func(p):
|
|
143
|
+
return model.photon_flux(p, *energies)
|
|
144
|
+
|
|
145
|
+
for _ in parameter_dims:
|
|
146
|
+
flux_func = jax.vmap(flux_func)
|
|
147
|
+
|
|
148
|
+
flux_func = jax.jit(flux_func)
|
|
149
|
+
|
|
150
|
+
if sparse:
|
|
151
|
+
# folding.transfer_matrix.data.density > 0.015 is a good criterion to consider sparsify
|
|
152
|
+
transfer_matrix = BCOO.from_scipy_sparse(
|
|
153
|
+
obs_configuration.transfer_matrix.data.to_scipy_sparse().tocsr()
|
|
154
|
+
)
|
|
155
|
+
|
|
156
|
+
else:
|
|
157
|
+
transfer_matrix = np.asarray(obs_configuration.transfer_matrix.data.todense())
|
|
158
|
+
|
|
159
|
+
expected_counts = jnp.matvec(transfer_matrix, flux_func(parameters))
|
|
160
|
+
|
|
161
|
+
# The result is clipped at 1e-6 to avoid 0 round-off and diverging likelihoods
|
|
162
|
+
return jnp.clip(expected_counts, a_min=1e-6)
|
|
163
|
+
|
|
164
|
+
|
|
127
165
|
def fakeit_for_multiple_parameters(
|
|
128
|
-
|
|
166
|
+
obsconfs: ObsConfiguration | list[ObsConfiguration],
|
|
129
167
|
model: SpectralModel,
|
|
130
168
|
parameters: Mapping[K, V],
|
|
131
169
|
rng_key: int = 0,
|
|
@@ -134,10 +172,32 @@ def fakeit_for_multiple_parameters(
|
|
|
134
172
|
):
|
|
135
173
|
"""
|
|
136
174
|
Convenience function to simulate multiple spectra from a given model and a set of parameters.
|
|
175
|
+
This is supposed to be somewhat optimized and can handle multiple parameters at once without blowing
|
|
176
|
+
up the memory. The parameters should be passed as a dictionary with the parameter name as the key and
|
|
177
|
+
the parameter values as the values, the value can be a scalar or a nd-array.
|
|
178
|
+
|
|
179
|
+
# Example:
|
|
180
|
+
|
|
181
|
+
``` python
|
|
182
|
+
from jaxspec.data.util import fakeit_for_multiple_parameters
|
|
183
|
+
from numpy.random import default_rng
|
|
184
|
+
|
|
185
|
+
rng = default_rng(42)
|
|
186
|
+
size = (10, 30)
|
|
187
|
+
|
|
188
|
+
parameters = {
|
|
189
|
+
"tbabs_1_nh": rng.uniform(0.1, 0.4, size=size),
|
|
190
|
+
"powerlaw_1_alpha": rng.uniform(1, 3, size=size),
|
|
191
|
+
"powerlaw_1_norm": rng.exponential(10 ** (-0.5), size=size),
|
|
192
|
+
"blackbodyrad_1_kT": rng.uniform(0.1, 3.0, size=size),
|
|
193
|
+
"blackbodyrad_1_norm": rng.exponential(10 ** (-3), size=size)
|
|
194
|
+
}
|
|
137
195
|
|
|
196
|
+
spectra = fakeit_for_multiple_parameters(obsconf, model, parameters)
|
|
197
|
+
```
|
|
138
198
|
|
|
139
199
|
Parameters:
|
|
140
|
-
|
|
200
|
+
obsconfs: The observational setup(s).
|
|
141
201
|
model: The model to use.
|
|
142
202
|
parameters: The parameters of the model.
|
|
143
203
|
rng_key: The random number generator seed.
|
|
@@ -145,12 +205,12 @@ def fakeit_for_multiple_parameters(
|
|
|
145
205
|
sparsify_matrix: Whether to sparsify the matrix or not.
|
|
146
206
|
"""
|
|
147
207
|
|
|
148
|
-
|
|
208
|
+
obsconf_list = [obsconfs] if isinstance(obsconfs, ObsConfiguration) else obsconfs
|
|
149
209
|
fakeits = []
|
|
150
210
|
|
|
151
|
-
for i,
|
|
152
|
-
countrate =
|
|
153
|
-
parameters
|
|
211
|
+
for i, obsconf in enumerate(obsconf_list):
|
|
212
|
+
countrate = forward_model_with_multiple_inputs(
|
|
213
|
+
model, parameters, obsconf, sparse=sparsify_matrix
|
|
154
214
|
)
|
|
155
215
|
|
|
156
216
|
if apply_stat:
|
|
@@ -13,7 +13,9 @@ import matplotlib.pyplot as plt
|
|
|
13
13
|
import numpyro
|
|
14
14
|
|
|
15
15
|
from jax import random
|
|
16
|
+
from jax.experimental import mesh_utils
|
|
16
17
|
from jax.random import PRNGKey
|
|
18
|
+
from jax.sharding import PositionalSharding
|
|
17
19
|
from numpyro.contrib.nested_sampling import NestedSampler
|
|
18
20
|
from numpyro.distributions import Poisson, TransformedDistribution
|
|
19
21
|
from numpyro.infer import AIES, ESS, MCMC, NUTS, Predictive
|
|
@@ -312,14 +314,27 @@ class BayesianModel:
|
|
|
312
314
|
Check if the prior distribution include the observed data.
|
|
313
315
|
"""
|
|
314
316
|
key_prior, key_posterior = jax.random.split(key, 2)
|
|
317
|
+
n_devices = len(jax.local_devices())
|
|
318
|
+
sharding = PositionalSharding(mesh_utils.create_device_mesh((n_devices,)))
|
|
319
|
+
|
|
320
|
+
# Sample from prior and correct if the number of samples is not a multiple of the number of devices
|
|
321
|
+
if num_samples % n_devices != 0:
|
|
322
|
+
num_samples = num_samples + n_devices - (num_samples % n_devices)
|
|
323
|
+
|
|
315
324
|
prior_params = self.prior_samples(key=key_prior, num_samples=num_samples)
|
|
316
|
-
|
|
325
|
+
|
|
326
|
+
# Split the parameters on every device
|
|
327
|
+
sharded_parameters = jax.device_put(prior_params, sharding)
|
|
328
|
+
posterior_observations = self.mock_observations(sharded_parameters, key=key_posterior)
|
|
317
329
|
|
|
318
330
|
for key, value in self.observation_container.items():
|
|
319
331
|
fig, ax = plt.subplots(
|
|
320
332
|
nrows=2, ncols=1, sharex=True, figsize=(5, 6), height_ratios=[3, 1]
|
|
321
333
|
)
|
|
322
334
|
|
|
335
|
+
legend_plots = []
|
|
336
|
+
legend_labels = []
|
|
337
|
+
|
|
323
338
|
y_observed, y_observed_low, y_observed_high = _error_bars_for_observed_data(
|
|
324
339
|
value.folded_counts.values, 1.0, "ct"
|
|
325
340
|
)
|
|
@@ -337,6 +352,11 @@ class BayesianModel:
|
|
|
337
352
|
ax[0], value.out_energies, posterior_observations["obs_" + key], n_sigmas=3
|
|
338
353
|
)
|
|
339
354
|
|
|
355
|
+
legend_plots.append((true_data_plot,))
|
|
356
|
+
legend_labels.append("Observed")
|
|
357
|
+
legend_plots += prior_plot
|
|
358
|
+
legend_labels.append("Prior Predictive")
|
|
359
|
+
|
|
340
360
|
# rank = np.vstack((posterior_observations["obs_" + key], value.folded_counts.values)).argsort(axis=0)[-1] / (num_samples) * 100
|
|
341
361
|
counts = posterior_observations["obs_" + key]
|
|
342
362
|
observed = value.folded_counts.values
|
|
@@ -363,7 +383,7 @@ class BayesianModel:
|
|
|
363
383
|
ax[1].set_ylim(0, 100)
|
|
364
384
|
ax[0].set_xlim(value.out_energies.min(), value.out_energies.max())
|
|
365
385
|
ax[0].loglog()
|
|
366
|
-
ax[0].legend(
|
|
386
|
+
ax[0].legend(legend_plots, legend_labels)
|
|
367
387
|
plt.suptitle(f"Prior Predictive coverage for {key}")
|
|
368
388
|
plt.tight_layout()
|
|
369
389
|
plt.show()
|
|
@@ -156,13 +156,13 @@ class Gauss(AdditiveComponent):
|
|
|
156
156
|
$$\mathcal{M}\left( E \right) = \frac{K}{\sigma \sqrt{2 \pi}}\exp\left(\frac{-(E-E_L)^2}{2\sigma^2}\right)$$
|
|
157
157
|
|
|
158
158
|
!!! abstract "Parameters"
|
|
159
|
-
* $E_L$ (`
|
|
159
|
+
* $E_L$ (`El`) $\left[\text{keV}\right]$ : Energy of the line
|
|
160
160
|
* $\sigma$ (`sigma`) $\left[\text{keV}\right]$ : Width of the line
|
|
161
161
|
* $K$ (`norm`) $\left[\frac{\text{photons}}{\text{cm}^2\text{s}}\right]$ : Normalization
|
|
162
162
|
"""
|
|
163
163
|
|
|
164
164
|
def __init__(self):
|
|
165
|
-
self.
|
|
165
|
+
self.El = nnx.Param(2.0)
|
|
166
166
|
self.sigma = nnx.Param(1e-2)
|
|
167
167
|
self.norm = nnx.Param(1.0)
|
|
168
168
|
|
|
@@ -170,12 +170,12 @@ class Gauss(AdditiveComponent):
|
|
|
170
170
|
return self.norm * (
|
|
171
171
|
jsp.stats.norm.cdf(
|
|
172
172
|
e_high,
|
|
173
|
-
loc=jnp.asarray(self.
|
|
173
|
+
loc=jnp.asarray(self.El, dtype=jnp.float64),
|
|
174
174
|
scale=jnp.asarray(self.sigma, dtype=jnp.float64),
|
|
175
175
|
)
|
|
176
176
|
- jsp.stats.norm.cdf(
|
|
177
177
|
e_low,
|
|
178
|
-
loc=jnp.asarray(self.
|
|
178
|
+
loc=jnp.asarray(self.El, dtype=jnp.float64),
|
|
179
179
|
scale=jnp.asarray(self.sigma, dtype=jnp.float64),
|
|
180
180
|
)
|
|
181
181
|
)
|
|
@@ -246,13 +246,13 @@ class Agauss(AdditiveComponent):
|
|
|
246
246
|
\frac{K}{\sigma \sqrt{2 \pi}} \exp\left(\frac{-(\lambda - \lambda_L)^2}{2 \sigma^2}\right)$$
|
|
247
247
|
|
|
248
248
|
!!! abstract "Parameters"
|
|
249
|
-
* $\lambda_L$ (`
|
|
249
|
+
* $\lambda_L$ (`lambdal`) $\left[\unicode{x212B}\right]$ : Wavelength of the line in Angström
|
|
250
250
|
* $\sigma$ (`sigma`) $\left[\unicode{x212B}\right]$ : Width of the line width in Angström
|
|
251
251
|
* $K$ (`norm`) $\left[\frac{\unicode{x212B}~\text{photons}}{\text{keV}\text{cm}^2\text{s}}\right]$: Normalization
|
|
252
252
|
"""
|
|
253
253
|
|
|
254
254
|
def __init__(self):
|
|
255
|
-
self.
|
|
255
|
+
self.lambdal = nnx.Param(12.0)
|
|
256
256
|
self.sigma = nnx.Param(1e-2)
|
|
257
257
|
self.norm = nnx.Param(1.0)
|
|
258
258
|
|
|
@@ -261,7 +261,7 @@ class Agauss(AdditiveComponent):
|
|
|
261
261
|
|
|
262
262
|
return self.norm * jsp.stats.norm.pdf(
|
|
263
263
|
hc / energy,
|
|
264
|
-
loc=jnp.asarray(self.
|
|
264
|
+
loc=jnp.asarray(self.lambdal, dtype=jnp.float64),
|
|
265
265
|
scale=jnp.asarray(self.sigma, dtype=jnp.float64),
|
|
266
266
|
)
|
|
267
267
|
|
|
@@ -275,14 +275,14 @@ class Zagauss(AdditiveComponent):
|
|
|
275
275
|
\frac{K (1+z)}{\sigma \sqrt{2 \pi}} \exp\left(\frac{-(\lambda/(1+z) - \lambda_L)^2}{2 \sigma^2}\right)$$
|
|
276
276
|
|
|
277
277
|
!!! abstract "Parameters"
|
|
278
|
-
* $\lambda_L$ (`
|
|
278
|
+
* $\lambda_L$ (`lambdal`) $\left[\unicode{x212B}\right]$ : Wavelength of the line in Angström
|
|
279
279
|
* $\sigma$ (`sigma`) $\left[\unicode{x212B}\right]$ : Width of the line width in Angström
|
|
280
280
|
* $z$ (`redshift`) $\left[\text{dimensionless}\right]$ : Redshift
|
|
281
281
|
* $K$ (`norm`) $\left[\frac{\unicode{x212B}~\text{photons}}{\text{keV}\text{cm}^2\text{s}}\right]$ : Normalization
|
|
282
282
|
"""
|
|
283
283
|
|
|
284
284
|
def __init__(self):
|
|
285
|
-
self.
|
|
285
|
+
self.lambdal = nnx.Param(12.0)
|
|
286
286
|
self.sigma = nnx.Param(1e-2)
|
|
287
287
|
self.redshift = nnx.Param(0.0)
|
|
288
288
|
self.norm = nnx.Param(1.0)
|
|
@@ -297,7 +297,7 @@ class Zagauss(AdditiveComponent):
|
|
|
297
297
|
* (1 + redshift)
|
|
298
298
|
* jsp.stats.norm.pdf(
|
|
299
299
|
(hc / energy) / (1 + redshift),
|
|
300
|
-
loc=jnp.asarray(self.
|
|
300
|
+
loc=jnp.asarray(self.lambdal, dtype=jnp.float64),
|
|
301
301
|
scale=jnp.asarray(self.sigma, dtype=jnp.float64),
|
|
302
302
|
)
|
|
303
303
|
)
|
|
@@ -311,14 +311,14 @@ class Zgauss(AdditiveComponent):
|
|
|
311
311
|
\frac{K}{(1+z) \sigma \sqrt{2 \pi}}\exp\left(\frac{-(E(1+z)-E_L)^2}{2\sigma^2}\right)$$
|
|
312
312
|
|
|
313
313
|
!!! abstract "Parameters"
|
|
314
|
-
* $E_L$ (`
|
|
314
|
+
* $E_L$ (`El`) $\left[\text{keV}\right]$ : Energy of the line
|
|
315
315
|
* $\sigma$ (`sigma`) $\left[\text{keV}\right]$ : Width of the line
|
|
316
316
|
* $K$ (`norm`) $\left[\frac{\text{photons}}{\text{cm}^2\text{s}}\right]$ : Normalization
|
|
317
317
|
* $z$ (`redshift`) $\left[\text{dimensionless}\right]$ : Redshift
|
|
318
318
|
"""
|
|
319
319
|
|
|
320
320
|
def __init__(self):
|
|
321
|
-
self.
|
|
321
|
+
self.El = nnx.Param(2.0)
|
|
322
322
|
self.sigma = nnx.Param(1e-2)
|
|
323
323
|
self.redshift = nnx.Param(0.0)
|
|
324
324
|
self.norm = nnx.Param(1.0)
|
|
@@ -326,7 +326,7 @@ class Zgauss(AdditiveComponent):
|
|
|
326
326
|
def continuum(self, energy) -> (jax.Array, jax.Array):
|
|
327
327
|
return (self.norm / (1 + self.redshift)) * jsp.stats.norm.pdf(
|
|
328
328
|
energy * (1 + self.redshift),
|
|
329
|
-
loc=jnp.asarray(self.
|
|
329
|
+
loc=jnp.asarray(self.El, dtype=jnp.float64),
|
|
330
330
|
scale=jnp.asarray(self.sigma, dtype=jnp.float64),
|
|
331
331
|
)
|
|
332
332
|
|
|
@@ -228,9 +228,9 @@ class Tbpcf(MultiplicativeComponent):
|
|
|
228
228
|
self.nh = nnx.Param(1.0)
|
|
229
229
|
self.f = nnx.Param(0.2)
|
|
230
230
|
|
|
231
|
-
def
|
|
231
|
+
def factor(self, energy):
|
|
232
232
|
sigma = jnp.interp(energy, self.energy, self.sigma, left=1e9, right=0.0)
|
|
233
|
-
return self.f * jnp.exp(-self.nh * sigma) + (1 - self.f)
|
|
233
|
+
return self.f * jnp.exp(-self.nh * sigma) + (1.0 - self.f)
|
|
234
234
|
|
|
235
235
|
|
|
236
236
|
class FDcut(MultiplicativeComponent):
|
|
@@ -250,5 +250,5 @@ class FDcut(MultiplicativeComponent):
|
|
|
250
250
|
self.Ec = nnx.Param(1.0)
|
|
251
251
|
self.Ef = nnx.Param(3.0)
|
|
252
252
|
|
|
253
|
-
def
|
|
253
|
+
def factor(self, energy):
|
|
254
254
|
return (1 + jnp.exp((energy - self.Ec) / self.Ef)) ** -1
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|