jaxspec 0.2.2__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,78 @@
1
+ import jax
2
+
3
+ from jax import random
4
+ from numpyro.contrib.nested_sampling import NestedSampler
5
+
6
+ from ..analysis.results import FitResult
7
+ from ..fit._fitter import BayesianModelFitter
8
+
9
+
10
class NSFitter(BayesianModelFitter):
    r"""
    Fitter that samples the posterior of a model with the Nested Sampling algorithm. Sampling is driven through the
    `NestedSampler` wrapper of `numpyro.contrib.nested_sampling`, which relies on the
    [`DefaultNestedSampler`][jaxns.DefaultNestedSampler] from [`jaxns`](https://jaxns.readthedocs.io/en/latest/), an
    implementation of the [Phantom-Powered Nested Sampling](https://arxiv.org/abs/2312.11330) algorithm.

    !!! info
        Make sure the prior distributions cover a large enough prior volume, otherwise the algorithm may not yield
        proper results.

    """

    def fit(
        self,
        rng_key: int = 0,
        num_samples: int = 1000,
        num_live_points: int = 1000,
        plot_diagnostics=False,
        termination_kwargs: dict | None = None,
        verbose=True,
        use_transformed_model: bool = True,
    ) -> FitResult:
        """
        Run the Phantom-Powered nested sampling algorithm and wrap the posterior samples in a fit result.

        Parameters:
            rng_key: integer seed turned into the random key of the sampler.
            num_samples: the number of posterior samples to draw once the run is over.
            num_live_points: the number of live points to use at the start of the NS algorithm.
            plot_diagnostics: whether to plot the diagnostics of the NS algorithm.
            termination_kwargs: additional arguments to pass to the termination criterion of the NS algorithm.
                `None` or an empty dict means no extra arguments.
            verbose: whether to print the progress of the NS algorithm.
            use_transformed_model: forwarded as is to `build_inference_data`. The sampler itself always runs on
                `transformed_numpyro_model`, whatever the value of this flag.

        Returns:
            A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
        """

        model = self.transformed_numpyro_model

        # Four sub-keys are derived from the seed; the first drives the run, the second the posterior draws.
        sub_keys = random.split(random.PRNGKey(rng_key), 4)
        run_key = sub_keys[0]
        draw_key = sub_keys[1]

        sampler_options = {
            "verbose": verbose,
            "difficult_model": True,
            "max_samples": 1e5,
            "parameter_estimation": True,
            "gradient_guided": True,
            "devices": jax.devices(),
            "num_live_points": num_live_points,
        }

        sampler = NestedSampler(
            model,
            constructor_kwargs=sampler_options,
            termination_kwargs=termination_kwargs or {},
        )
        sampler.run(run_key)

        if plot_diagnostics:
            sampler.diagnostics()

        posterior_samples = sampler.get_samples(draw_key, num_samples=num_samples)
        inference_data = self.build_inference_data(
            posterior_samples,
            num_chains=1,
            use_transformed_model=use_transformed_model,
        )

        return FitResult(self, inference_data, background_model=self.background_model)
@@ -0,0 +1,264 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ import numpy as np
4
+ import xarray as xr
5
+
6
+ from astropy.table import Table
7
+ from flax import nnx
8
+ from jax import lax
9
+ from jax.scipy.interpolate import RegularGridInterpolator
10
+ from jax.typing import ArrayLike
11
+ from tqdm.auto import tqdm
12
+
13
+ from ..model.abc import AdditiveComponent
14
+ from .interpolator import RegularGridInterpolatorWithGrad
15
+
16
+
17
+ class TableManager:
18
+ """
19
+ Handler for the tabulated data from `xspec` tabulated models. The table must follow the format speficied in the
20
+ [`atable`](https://heasarc.gsfc.nasa.gov/docs/xanadu/xspec/manual/XSmodelAtable.html) additive component.
21
+ """
22
+
23
+ def __init__(self, table_path):
24
+ self._table_path = table_path
25
+
26
+ # Read the tables once.
27
+ raw_parameters = Table.read(self._table_path, "PARAMETERS")
28
+ raw_energies = Table.read(self._table_path, "ENERGIES")
29
+ raw_spectra = Table.read(self._table_path, "SPECTRA")
30
+
31
+ # Build and wrap the parameter table.
32
+ parameter_table = tuple(np.asarray(row["VALUE"], dtype=float) for row in raw_parameters)
33
+ parameters = []
34
+ for row in raw_parameters:
35
+ parameter_name = row["NAME"].rstrip()
36
+ parameters.append((parameter_name, float(row["INITIAL"])))
37
+
38
+ # Process and wrap the energies table.
39
+ energy_low = np.asarray(raw_energies["ENERG_LO"].value, dtype=float)
40
+ energy_high = np.asarray(raw_energies["ENERG_HI"].value, dtype=float)
41
+ energies_table = np.vstack((energy_low, energy_high))
42
+
43
+ # Compute the parameter shape
44
+ parameter_shape = tuple()
45
+ for parameter in parameter_table:
46
+ parameter_shape += parameter.shape
47
+
48
+ # Build and wrap the spectra table.
49
+ total_shape = parameter_shape + energy_low.shape
50
+
51
+ spectra_table = np.empty(total_shape, dtype=float)
52
+ for idx, par_indexes in enumerate(np.ndindex(*parameter_shape)):
53
+ spectra_table[par_indexes] = np.asarray(raw_spectra["INTPSPEC"][idx], dtype=float)
54
+
55
+ self.parameters = parameters
56
+ self.energies_table = energies_table
57
+ self.parameter_table = parameter_table
58
+ self.parameter_shape = parameter_shape
59
+ self.spectra_table = spectra_table
60
+
61
+ def check_proper_ordering(self):
62
+ """
63
+ Assert that the parameter ordering in the spectra table is consistent with the parameter table.
64
+ """
65
+
66
+ spectra_table = Table.read(self._table_path, "SPECTRA")
67
+ pars_mesh = np.meshgrid(*self.parameter_table, indexing="ij", sparse=False)
68
+
69
+ for idx, par_indexes in enumerate(
70
+ tqdm(np.ndindex(*self.parameter_shape), total=len(spectra_table))
71
+ ):
72
+ expected = np.asarray([par[par_indexes] for par in pars_mesh])
73
+ obtained = jnp.asarray(spectra_table["PARAMVAL"][idx], dtype=float)
74
+ assert jnp.allclose(expected, obtained), "Parameter ordering mismatch"
75
+
76
+
77
+ class TabulatedModel(nnx.Module):
78
+ """
79
+ See https://heasarc.gsfc.nasa.gov/docs/heasarc/caldb/docs/memos/ogip_92_009/ogip_92_009.pdf
80
+ """
81
+
82
+ def __init__(self, table_path: str):
83
+ table_manager = TableManager(table_path)
84
+
85
+ self._parameters = table_manager.parameters
86
+ self._parameter_table = table_manager.parameter_table
87
+ self._spectra_table = table_manager.spectra_table
88
+ self._energies_table = table_manager.energies_table
89
+
90
+ # Instantiate the parameters of the nnx.Module
91
+ for parameter_name, value in table_manager.parameters:
92
+ setattr(self, parameter_name, nnx.Param(value))
93
+
94
+ del table_manager
95
+
96
+ def get_parameter_list(self):
97
+ parameter_list = []
98
+ for parameter_name, _ in self._parameters:
99
+ parameter_list.append(getattr(self, parameter_name))
100
+
101
+ return jnp.asarray(parameter_list, dtype=float)
102
+
103
+
104
+ """
105
+ import numpy as np
106
+ from scipy.sparse import csr_matrix
107
+
108
+ def redistribution_matrix(old_e_low, old_e_high, new_e_low, new_e_high):
109
+ new_bins = jnp.stack([new_e_low, new_e_high], axis=1)
110
+
111
+ def scan_body(carry, new_bin):
112
+ new_low, new_high = new_bin
113
+
114
+ # Compute the overlap between the current new bin and all old bins.
115
+ lower_bounds = jnp.maximum(old_e_low, new_low)
116
+ upper_bounds = jnp.minimum(old_e_high, new_high)
117
+ overlap_fraction = jnp.maximum(0, upper_bounds - lower_bounds) / (old_e_high - old_e_low)
118
+
119
+ return carry, overlap_fraction
120
+
121
+ _, matrix = lax.scan(scan_body, None, new_bins)
122
+ return matrix.T
123
+ """
124
+
125
+
126
+ def redistribute(
127
+ integrated_spectrum: ArrayLike,
128
+ old_e_low: ArrayLike,
129
+ old_e_high: ArrayLike,
130
+ e_low: ArrayLike,
131
+ e_high: ArrayLike,
132
+ ) -> ArrayLike:
133
+ """
134
+ Redistribute the integrated spectrum over the new energy bins.
135
+
136
+ Parameters:
137
+ integrated_spectrum: Integrated spectrum to redistribute.
138
+ old_e_low: Lower bounds of the old energy bins.
139
+ old_e_high: Upper bounds of the old energy bins.
140
+ e_low: Lower bounds of the new energy bins.
141
+ e_high: Upper bounds of the new energy bins.
142
+ """
143
+ new_bins = jnp.stack([e_low, e_high], axis=1)
144
+
145
+ def scan_body(carry, new_bin):
146
+ new_low, new_high = new_bin
147
+
148
+ # Compute the overlap between the current new bin and all old bins.
149
+ lower_bounds = jnp.maximum(old_e_low, new_low)
150
+ upper_bounds = jnp.minimum(old_e_high, new_high)
151
+ overlap_fraction = jnp.maximum(0, upper_bounds - lower_bounds) / (old_e_high - old_e_low)
152
+
153
+ # Sum over old bins: each old bin contributes its integrated value times
154
+ # the fraction of the new bin that overlaps with it.
155
+ new_intensity = jnp.sum(overlap_fraction * integrated_spectrum)
156
+
157
+ return carry, new_intensity
158
+
159
+ _, redistributed_values = lax.scan(scan_body, None, new_bins)
160
+ return redistributed_values
161
+
162
+
163
+ class AdditiveTabulated(AdditiveComponent, TabulatedModel):
164
+ """
165
+ Equivalent of the [`atable`](https://heasarc.gsfc.nasa.gov/docs/xanadu/xspec/manual/XSmodelAtable.html) additive
166
+ component in `xspec`.
167
+ """
168
+
169
+ def __init__(self, table_path: str):
170
+ super().__init__(table_path)
171
+
172
+ self.norm = nnx.Param(1.0)
173
+
174
+ self._interpolator = RegularGridInterpolator(
175
+ self._parameter_table, self._spectra_table, fill_value=0.0
176
+ )
177
+
178
+ def _integrate_on_grid(self):
179
+ return self._interpolator(self.get_parameter_list()).squeeze()
180
+
181
+ def integrated_continuum(self, e_low, e_high):
182
+ integrated_spectrum = self._integrate_on_grid()
183
+ return jnp.asarray(self.norm) * redistribute(
184
+ integrated_spectrum, *self._energies_table, e_low, e_high
185
+ )
186
+
187
+
188
+ class TabulatedModelXarray(nnx.Module):
189
+ """
190
+ See https://heasarc.gsfc.nasa.gov/docs/heasarc/caldb/docs/memos/ogip_92_009/ogip_92_009.pdf
191
+ """
192
+
193
+ def __init__(self, table_path: str):
194
+ ds = xr.open_dataset(table_path)
195
+
196
+ self._parameters = [dim for dim in ds["spectra"].dims if dim != "energy"]
197
+ self._parameter_table = [np.asarray(ds[key]) for key in self._parameters]
198
+ self._spectra_table = ds["spectra"]
199
+ self._energies_table = (np.asarray(ds["energy_low"]), np.asarray(ds["energy_high"]))
200
+
201
+ # Instantiate the parameters of the nnx.Module
202
+ for parameter_name, parameter_value in zip(self._parameters, self._parameter_table):
203
+ setattr(
204
+ self,
205
+ parameter_name,
206
+ nnx.Param(np.median(parameter_value) + np.random.uniform(-5e-1, +5e-1)),
207
+ )
208
+
209
+ def get_parameter_list(self):
210
+ parameter_list = []
211
+ for parameter_name in self._parameters:
212
+ parameter_list.append(getattr(self, parameter_name))
213
+
214
+ return jnp.asarray(parameter_list, dtype=float)
215
+
216
+
217
+ class AdditiveTabulatedXarray(AdditiveComponent, TabulatedModelXarray):
218
+ def __init__(self, table_path: str):
219
+ super().__init__(table_path)
220
+
221
+ self.norm = nnx.Param(1.0)
222
+
223
+ interpolator = RegularGridInterpolatorWithGrad(
224
+ self._parameter_table, np.asarray(self._spectra_table)
225
+ )
226
+
227
+ def callback(pars):
228
+ value, grad = interpolator(pars)
229
+ return np.vstack([value[None, :], grad])
230
+
231
+ result = callback(self.get_parameter_list())
232
+
233
+ out_type = jax.ShapeDtypeStruct(jnp.shape(result), jnp.result_type(result))
234
+
235
+ @jax.custom_jvp
236
+ def spectrum_interpolation(pars):
237
+ result = jax.pure_callback(
238
+ lambda p: callback(p), out_type, pars, vmap_method="legacy_vectorized"
239
+ )
240
+ return result[0, ...]
241
+
242
+ @spectrum_interpolation.defjvp
243
+ def spectrum_interpolation_jvp(primals, tangents):
244
+ pars = primals
245
+ pars_dot = tangents
246
+
247
+ result = jax.pure_callback(
248
+ lambda p: callback(p), out_type, pars, vmap_method="legacy_vectorized"
249
+ )
250
+ value = result[0, ...]
251
+ grad = result[1:, ...]
252
+
253
+ return value, jnp.squeeze(jnp.asarray(pars_dot) @ grad)
254
+
255
+ self._spectrum_interpolation = spectrum_interpolation
256
+
257
+ def _integrate_on_grid(self):
258
+ return self._spectrum_interpolation(jnp.asarray(self.get_parameter_list()))
259
+
260
+ def integrated_continuum(self, e_low, e_high):
261
+ integrated_spectrum = self._integrate_on_grid()
262
+ return jnp.asarray(self.norm) * redistribute(
263
+ integrated_spectrum, *self._energies_table, e_low, e_high
264
+ )
@@ -0,0 +1,3 @@
1
+ from ._bayesian_model import BayesianModel
2
+ from ._build_model import TiedParameter, build_prior, forward_model
3
+ from ._fitter import MCMCFitter, VIFitter