jaxspec 0.2.2.dev0__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxspec/analysis/_plot.py +5 -5
- jaxspec/analysis/results.py +41 -26
- jaxspec/data/obsconf.py +9 -3
- jaxspec/data/observation.py +3 -1
- jaxspec/data/ogip.py +9 -2
- jaxspec/data/util.py +17 -11
- jaxspec/experimental/interpolator.py +74 -0
- jaxspec/experimental/interpolator_jax.py +79 -0
- jaxspec/experimental/intrument_models.py +159 -0
- jaxspec/experimental/nested_sampler.py +78 -0
- jaxspec/experimental/tabulated.py +264 -0
- jaxspec/fit/__init__.py +3 -0
- jaxspec/{fit.py → fit/_bayesian_model.py} +84 -336
- jaxspec/{_fit → fit}/_build_model.py +42 -6
- jaxspec/fit/_fitter.py +255 -0
- jaxspec/model/abc.py +52 -80
- jaxspec/model/additive.py +14 -5
- jaxspec/model/background.py +17 -14
- jaxspec/model/instrument.py +81 -0
- jaxspec/model/list.py +4 -1
- jaxspec/model/multiplicative.py +32 -12
- jaxspec/util/integrate.py +17 -5
- {jaxspec-0.2.2.dev0.dist-info → jaxspec-0.3.1.dist-info}/METADATA +11 -11
- jaxspec-0.3.1.dist-info/RECORD +42 -0
- jaxspec-0.2.2.dev0.dist-info/RECORD +0 -34
- /jaxspec/{_fit → experimental}/__init__.py +0 -0
- {jaxspec-0.2.2.dev0.dist-info → jaxspec-0.3.1.dist-info}/WHEEL +0 -0
- {jaxspec-0.2.2.dev0.dist-info → jaxspec-0.3.1.dist-info}/entry_points.txt +0 -0
- {jaxspec-0.2.2.dev0.dist-info → jaxspec-0.3.1.dist-info}/licenses/LICENSE.md +0 -0
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
import jax
|
|
2
|
+
|
|
3
|
+
from jax import random
|
|
4
|
+
from numpyro.contrib.nested_sampling import NestedSampler
|
|
5
|
+
|
|
6
|
+
from ..analysis.results import FitResult
|
|
7
|
+
from ..fit._fitter import BayesianModelFitter
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class NSFitter(BayesianModelFitter):
    r"""
    Fit a model to a given set of observations with the Nested Sampling algorithm. The sampling is delegated to the
    [`DefaultNestedSampler`][jaxns.DefaultNestedSampler] of [`jaxns`](https://jaxns.readthedocs.io/en/latest/)
    (through `numpyro.contrib.nested_sampling.NestedSampler`), which implements the
    [Phantom-Powered Nested Sampling](https://arxiv.org/abs/2312.11330) algorithm.

    !!! info
        Make sure the prior distributions cover a large enough volume for the algorithm to yield proper results.

    """

    def fit(
        self,
        rng_key: int = 0,
        num_samples: int = 1000,
        num_live_points: int = 1000,
        plot_diagnostics=False,
        termination_kwargs: dict | None = None,
        verbose=True,
        use_transformed_model: bool = True,
    ) -> FitResult:
        """
        Fit the model to the data using the Phantom-Powered nested sampling algorithm.

        Parameters:
            rng_key: seed of the random key used to run the sampler and to draw the posterior samples.
            num_samples: the number of posterior samples to draw.
            num_live_points: the number of live points to use at the start of the NS algorithm.
            plot_diagnostics: whether to plot the diagnostics of the NS algorithm.
            termination_kwargs: additional arguments to pass to the termination criterion of the NS algorithm.
            verbose: whether to print the progress of the NS algorithm.
            use_transformed_model: forwarded to `build_inference_data` when assembling the results.

        Returns:
            A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
        """
        # NOTE(review): the sampler always runs on `transformed_numpyro_model`, whatever the value of
        # `use_transformed_model` (only forwarded to `build_inference_data`) — confirm this is intended.
        model = self.transformed_numpyro_model

        # Four sub-keys are derived (keeping the random streams stable); only the first two are consumed.
        run_key, sample_key, *_ = random.split(random.PRNGKey(rng_key), 4)

        sampler_settings = {
            "verbose": verbose,
            "difficult_model": True,
            "max_samples": 1e5,
            "parameter_estimation": True,
            "gradient_guided": True,
            "devices": jax.devices(),
            "num_live_points": num_live_points,
        }

        sampler = NestedSampler(
            model,
            constructor_kwargs=sampler_settings,
            termination_kwargs=termination_kwargs or {},
        )
        sampler.run(run_key)

        if plot_diagnostics:
            sampler.diagnostics()

        samples = sampler.get_samples(sample_key, num_samples=num_samples)

        return FitResult(
            self,
            self.build_inference_data(samples, num_chains=1, use_transformed_model=use_transformed_model),
            background_model=self.background_model,
        )
|
|
@@ -0,0 +1,264 @@
|
|
|
1
|
+
import jax
|
|
2
|
+
import jax.numpy as jnp
|
|
3
|
+
import numpy as np
|
|
4
|
+
import xarray as xr
|
|
5
|
+
|
|
6
|
+
from astropy.table import Table
|
|
7
|
+
from flax import nnx
|
|
8
|
+
from jax import lax
|
|
9
|
+
from jax.scipy.interpolate import RegularGridInterpolator
|
|
10
|
+
from jax.typing import ArrayLike
|
|
11
|
+
from tqdm.auto import tqdm
|
|
12
|
+
|
|
13
|
+
from ..model.abc import AdditiveComponent
|
|
14
|
+
from .interpolator import RegularGridInterpolatorWithGrad
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class TableManager:
    """
    Handler for the tabulated data from `xspec` tabulated models. The table must follow the format specified in the
    [`atable`](https://heasarc.gsfc.nasa.gov/docs/xanadu/xspec/manual/XSmodelAtable.html) additive component.

    Attributes:
        parameters: list of `(name, initial value)` pairs, one per row of the PARAMETERS extension.
        energies_table: array of shape `(2, n_energies)` with the lower and upper bounds of the energy bins.
        parameter_table: tuple holding the 1D grid of tabulated values of each parameter.
        parameter_shape: shape of the parameter grid (number of tabulated values of each parameter).
        spectra_table: array of shape `parameter_shape + (n_energies,)` holding the tabulated spectra.
    """

    def __init__(self, table_path):
        self._table_path = table_path

        # Read the tables once.
        raw_parameters = Table.read(self._table_path, "PARAMETERS")
        raw_energies = Table.read(self._table_path, "ENERGIES")
        raw_spectra = Table.read(self._table_path, "SPECTRA")

        # Build the parameter grids. The VALUE column is a fixed-width vector padded up to the largest number of
        # tabulated values, so only the first NUMBVALS entries of each row are meaningful.
        has_numbvals = "NUMBVALS" in raw_parameters.colnames
        grids = []
        parameters = []
        for row in raw_parameters:
            numbvals = row["NUMBVALS"] if has_numbvals else None
            grids.append(self._grid_values(row["VALUE"], numbvals))
            parameters.append((row["NAME"].rstrip(), float(row["INITIAL"])))
        parameter_table = tuple(grids)

        # Process the energies table into a (2, n_energies) array of bin bounds.
        energy_low = np.asarray(raw_energies["ENERG_LO"].value, dtype=float)
        energy_high = np.asarray(raw_energies["ENERG_HI"].value, dtype=float)
        energies_table = np.vstack((energy_low, energy_high))

        # One grid axis per parameter, sized by its number of tabulated values.
        parameter_shape = tuple(len(grid) for grid in parameter_table)

        # Place the spectra on the parameter grid (last parameter varying fastest, see `check_proper_ordering`).
        spectra_table = self._reshape_spectra(raw_spectra["INTPSPEC"], parameter_shape, len(energy_low))

        self.parameters = parameters
        self.energies_table = energies_table
        self.parameter_table = parameter_table
        self.parameter_shape = parameter_shape
        self.spectra_table = spectra_table

    @staticmethod
    def _grid_values(value, numbvals=None) -> np.ndarray:
        """
        Return the tabulated values of one parameter as a 1D float array.

        Parameters:
            value: content of the VALUE column for this parameter.
            numbvals: content of the NUMBVALS column, or None when the column is absent. Entries beyond it are
                padding and are dropped.
        """
        values = np.atleast_1d(np.asarray(value, dtype=float))
        if numbvals is not None:
            values = values[: int(numbvals)]
        return values

    @staticmethod
    def _reshape_spectra(rows, parameter_shape, n_energies) -> np.ndarray:
        """
        Arrange the flat sequence of tabulated spectra on the parameter grid.

        Parameters:
            rows: the INTPSPEC column, one spectrum per row, ordered with the last parameter varying fastest.
            parameter_shape: number of tabulated values of each parameter.
            n_energies: number of energy bins of each spectrum.

        Returns:
            Array of shape `parameter_shape + (n_energies,)`.

        Raises:
            ValueError: if the number of spectra or their length does not match the grid and the energy bins.
        """
        spectra = [np.atleast_1d(np.asarray(row, dtype=float)) for row in rows]
        expected_rows = int(np.prod(parameter_shape, dtype=int))

        if len(spectra) != expected_rows or any(spectrum.shape != (n_energies,) for spectrum in spectra):
            raise ValueError(
                f"Expected {expected_rows} spectra of {n_energies} energy bins for a parameter grid of shape "
                f"{parameter_shape}, got {len(spectra)} spectra."
            )

        # np.ndindex enumerates grid points in C order, so a plain reshape puts row i on the i-th grid point.
        return np.stack(spectra).reshape(parameter_shape + (n_energies,))

    def check_proper_ordering(self):
        """
        Check that the parameter ordering in the spectra table is consistent with the parameter table.

        The PARAMVAL column of the SPECTRA extension is compared with the parameter grid enumerated in C order
        (last parameter varying fastest), which is the layout assumed when building `spectra_table`.

        Raises:
            AssertionError: if the parameter values stored next to the spectra do not match the expected grid.
        """
        spectra_table = Table.read(self._table_path, "SPECTRA")
        pars_mesh = np.meshgrid(*self.parameter_table, indexing="ij", sparse=False)

        # Row i holds the parameter values of the i-th grid point in np.ndindex (C) order.
        expected = np.stack([par.ravel() for par in pars_mesh], axis=-1)
        obtained = np.asarray(spectra_table["PARAMVAL"], dtype=float)

        # Raised explicitly instead of using `assert`, so the check still runs under `python -O`;
        # AssertionError is kept for callers that already catch it.
        if obtained.size != expected.size or not np.allclose(expected, obtained.reshape(expected.shape)):
            raise AssertionError("Parameter ordering mismatch")
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class TabulatedModel(nnx.Module):
    """
    Base module exposing the parameters of an `xspec` tabulated model as `nnx.Param` attributes.

    See https://heasarc.gsfc.nasa.gov/docs/heasarc/caldb/docs/memos/ogip_92_009/ogip_92_009.pdf
    """

    def __init__(self, table_path: str):
        manager = TableManager(table_path)

        # Keep only the tabulated data; the manager itself is not stored on the module.
        self._parameters = manager.parameters
        self._parameter_table = manager.parameter_table
        self._spectra_table = manager.spectra_table
        self._energies_table = manager.energies_table

        # One nnx.Param per tabulated parameter, named after the table entry and set to its initial value.
        for name, initial_value in manager.parameters:
            setattr(self, name, nnx.Param(initial_value))

    def get_parameter_list(self):
        """Return the current values of the tabulated parameters, in table order, as a float array."""
        return jnp.asarray([getattr(self, name) for name, _ in self._parameters], dtype=float)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
"""
|
|
105
|
+
import numpy as np
|
|
106
|
+
from scipy.sparse import csr_matrix
|
|
107
|
+
|
|
108
|
+
def redistribution_matrix(old_e_low, old_e_high, new_e_low, new_e_high):
|
|
109
|
+
new_bins = jnp.stack([new_e_low, new_e_high], axis=1)
|
|
110
|
+
|
|
111
|
+
def scan_body(carry, new_bin):
|
|
112
|
+
new_low, new_high = new_bin
|
|
113
|
+
|
|
114
|
+
# Compute the overlap between the current new bin and all old bins.
|
|
115
|
+
lower_bounds = jnp.maximum(old_e_low, new_low)
|
|
116
|
+
upper_bounds = jnp.minimum(old_e_high, new_high)
|
|
117
|
+
overlap_fraction = jnp.maximum(0, upper_bounds - lower_bounds) / (old_e_high - old_e_low)
|
|
118
|
+
|
|
119
|
+
return carry, overlap_fraction
|
|
120
|
+
|
|
121
|
+
_, matrix = lax.scan(scan_body, None, new_bins)
|
|
122
|
+
return matrix.T
|
|
123
|
+
"""
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def redistribute(
    integrated_spectrum: ArrayLike,
    old_e_low: ArrayLike,
    old_e_high: ArrayLike,
    e_low: ArrayLike,
    e_high: ArrayLike,
) -> ArrayLike:
    """
    Redistribute the integrated spectrum over the new energy bins.

    Each old bin contributes its integrated value weighted by the fraction of its own width that overlaps the new
    bin, which amounts to assuming a uniform flux within each old bin. Parts of the new bins lying outside the old
    grid receive no flux, and zero-width old bins contribute nothing.

    Parameters:
        integrated_spectrum: Integrated spectrum to redistribute, one value per old energy bin.
        old_e_low: Lower bounds of the old energy bins.
        old_e_high: Upper bounds of the old energy bins.
        e_low: Lower bounds of the new energy bins.
        e_high: Upper bounds of the new energy bins.

    Returns:
        The integrated spectrum on the new energy bins, one value per new bin.
    """
    old_e_low = jnp.asarray(old_e_low)
    old_e_high = jnp.asarray(old_e_high)

    # The old bin widths do not depend on the new bin, so compute them once outside the scan. A zero-width old
    # bin can never overlap a new bin; a unit divisor avoids the 0/0 = NaN that would poison every sum.
    old_width = old_e_high - old_e_low
    safe_width = jnp.where(old_width > 0, old_width, 1.0)

    new_bins = jnp.stack([e_low, e_high], axis=1)

    def scan_body(carry, new_bin):
        new_low, new_high = new_bin

        # Overlap between the current new bin and every old bin, as a fraction of each old bin's width.
        lower_bounds = jnp.maximum(old_e_low, new_low)
        upper_bounds = jnp.minimum(old_e_high, new_high)
        overlap_fraction = jnp.maximum(0, upper_bounds - lower_bounds) / safe_width

        # Sum over old bins: each old bin contributes its integrated value times the fraction
        # of that old bin which overlaps the new bin.
        new_intensity = jnp.sum(overlap_fraction * integrated_spectrum)

        return carry, new_intensity

    # Scanning over new bins keeps memory linear in the number of old bins (no dense overlap matrix).
    _, redistributed_values = lax.scan(scan_body, None, new_bins)
    return redistributed_values
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
class AdditiveTabulated(AdditiveComponent, TabulatedModel):
    """
    Equivalent of the [`atable`](https://heasarc.gsfc.nasa.gov/docs/xanadu/xspec/manual/XSmodelAtable.html) additive
    component in `xspec`.
    """

    def __init__(self, table_path: str):
        super().__init__(table_path)

        self.norm = nnx.Param(1.0)

        # Interpolator over the parameter grid; points outside the grid evaluate to zero (fill_value=0.0).
        self._interpolator = RegularGridInterpolator(
            self._parameter_table, self._spectra_table, fill_value=0.0
        )

    def _integrate_on_grid(self):
        """Interpolate the tabulated spectra at the current parameter values, on the table energy bins."""
        interpolated = self._interpolator(self.get_parameter_list())
        return interpolated.squeeze()

    def integrated_continuum(self, e_low, e_high):
        """Return the normalised spectrum redistributed onto the bins delimited by `e_low` and `e_high`."""
        spectrum_on_table_bins = self._integrate_on_grid()
        table_low, table_high = self._energies_table
        rebinned = redistribute(spectrum_on_table_bins, table_low, table_high, e_low, e_high)
        return jnp.asarray(self.norm) * rebinned
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
class TabulatedModelXarray(nnx.Module):
    """
    Base module exposing the parameters of a tabulated model stored as an `xarray` dataset as `nnx.Param` attributes.

    The dataset must hold a `spectra` variable whose dimensions are the model parameters plus an `energy`
    dimension, together with `energy_low` and `energy_high` variables for the energy bin bounds.

    See https://heasarc.gsfc.nasa.gov/docs/heasarc/caldb/docs/memos/ogip_92_009/ogip_92_009.pdf
    """

    def __init__(self, table_path: str):
        # Load what is needed into memory and close the file, instead of leaving the dataset (and a lazy view
        # on it) open for the lifetime of the module.
        with xr.open_dataset(table_path) as ds:
            self._parameters = [dim for dim in ds["spectra"].dims if dim != "energy"]
            self._parameter_table = [np.asarray(ds[key]) for key in self._parameters]
            self._spectra_table = ds["spectra"].load()
            self._energies_table = (np.asarray(ds["energy_low"]), np.asarray(ds["energy_high"]))

        # Instantiate the parameters of the nnx.Module near the median of each grid. The random jitter
        # (presumably meant to avoid starting exactly on a grid node — confirm) is clipped so the starting
        # value never leaves the tabulated range of the parameter.
        for parameter_name, parameter_value in zip(self._parameters, self._parameter_table):
            start = np.median(parameter_value) + np.random.uniform(-5e-1, +5e-1)
            setattr(
                self,
                parameter_name,
                nnx.Param(np.clip(start, np.min(parameter_value), np.max(parameter_value))),
            )

    def get_parameter_list(self):
        """Return the current values of the tabulated parameters, in dataset dimension order, as a float array."""
        return jnp.asarray([getattr(self, name) for name in self._parameters], dtype=float)
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
class AdditiveTabulatedXarray(AdditiveComponent, TabulatedModelXarray):
    """
    Additive component interpolating spectra tabulated in an `xarray` dataset.

    The interpolation runs on the host through `jax.pure_callback`. Its derivatives with respect to the parameters
    come from `RegularGridInterpolatorWithGrad` and are wired into JAX with a custom JVP rule.
    """

    def __init__(self, table_path: str):
        super().__init__(table_path)

        self.norm = nnx.Param(1.0)

        interpolator = RegularGridInterpolatorWithGrad(
            self._parameter_table, np.asarray(self._spectra_table)
        )

        def callback(pars):
            # Pack the interpolated spectrum (first row) and its gradient rows (one per parameter)
            # into a single array, so that one host call returns both.
            value, grad = interpolator(pars)
            return np.vstack([value[None, :], grad])

        # Evaluate once at the initial parameters to learn the shape and dtype of the callback output.
        result = callback(self.get_parameter_list())
        out_type = jax.ShapeDtypeStruct(jnp.shape(result), jnp.result_type(result))

        def evaluate(pars):
            # `callback` only handles a single parameter vector, so under `jax.vmap` it must be invoked once
            # per batch element ("sequential") rather than with a whole batch ("legacy_vectorized").
            return jax.pure_callback(callback, out_type, pars, vmap_method="sequential")

        @jax.custom_jvp
        def spectrum_interpolation(pars):
            return evaluate(pars)[0, ...]

        @spectrum_interpolation.defjvp
        def spectrum_interpolation_jvp(primals, tangents):
            # Unpack the 1-tuples so the callback receives the same 1D vector as in the primal path, and the
            # tangent keeps the primal output shape (no squeeze needed, even with a single energy bin).
            (pars,) = primals
            (pars_dot,) = tangents

            packed = evaluate(pars)
            value = packed[0, ...]
            grad = packed[1:, ...]

            # Chain rule: contract the parameter tangent with the gradient rows.
            return value, jnp.asarray(pars_dot) @ grad

        self._spectrum_interpolation = spectrum_interpolation

    def _integrate_on_grid(self):
        """Interpolate the tabulated spectrum at the current parameter values, on the table energy bins."""
        return self._spectrum_interpolation(jnp.asarray(self.get_parameter_list()))

    def integrated_continuum(self, e_low, e_high):
        """Return the normalised spectrum redistributed onto the bins delimited by `e_low` and `e_high`."""
        integrated_spectrum = self._integrate_on_grid()
        return jnp.asarray(self.norm) * redistribute(
            integrated_spectrum, *self._energies_table, e_low, e_high
        )
|
jaxspec/fit/__init__.py
ADDED