jaxspec 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxspec/__init__.py +1 -1
- jaxspec/analysis/compare.py +3 -3
- jaxspec/analysis/results.py +239 -110
- jaxspec/data/instrument.py +0 -2
- jaxspec/data/ogip.py +18 -0
- jaxspec/data/util.py +11 -3
- jaxspec/fit.py +166 -72
- jaxspec/model/_additive/__init__.py +0 -0
- jaxspec/model/_additive/apec.py +377 -0
- jaxspec/model/_additive/apec_loaders.py +90 -0
- jaxspec/model/abc.py +55 -7
- jaxspec/model/additive.py +2 -51
- jaxspec/tables/abundances.dat +31 -0
- jaxspec/util/abundance.py +111 -0
- jaxspec/util/integrate.py +5 -4
- {jaxspec-0.0.4.dist-info → jaxspec-0.0.6.dist-info}/METADATA +5 -3
- {jaxspec-0.0.4.dist-info → jaxspec-0.0.6.dist-info}/RECORD +19 -14
- {jaxspec-0.0.4.dist-info → jaxspec-0.0.6.dist-info}/LICENSE.md +0 -0
- {jaxspec-0.0.4.dist-info → jaxspec-0.0.6.dist-info}/WHEEL +0 -0
jaxspec/fit.py
CHANGED
@@ -4,11 +4,12 @@ import numpyro
 import arviz as az
 import jax
 from typing import Callable, TypeVar
-from abc import ABC
+from abc import ABC, abstractmethod
 from jax import random
 from jax.tree_util import tree_map
+from jax.flatten_util import ravel_pytree
 from jax.experimental.sparse import BCSR
-from .analysis.results import
+from .analysis.results import FitResult
 from .model.abc import SpectralModel
 from .data import ObsConfiguration
 from .model.background import BackgroundModel
@@ -17,13 +18,15 @@ from numpyro.distributions import Distribution, TransformedDistribution
 from numpyro.distributions import Poisson
 from jax.typing import ArrayLike
 from numpyro.infer.reparam import TransformReparam
+from numpyro.infer.util import initialize_model
+from jax.random import PRNGKey
+import jaxopt


 T = TypeVar("T")


-class HaikuDict(dict[str, dict[str, T]]):
-    ...
+class HaikuDict(dict[str, dict[str, T]]): ...


 def build_prior(prior: HaikuDict[Distribution | ArrayLike], expand_shape: tuple = ()):
@@ -32,7 +35,8 @@ def build_prior(prior: HaikuDict[Distribution | ArrayLike], expand_shape: tuple
     for i, (m, n, sample) in enumerate(hk.data_structures.traverse(prior)):
         match sample:
             case Distribution():
-                parameters[m][n] = numpyro.sample(f"{m}_{n}", sample)
+                parameters[m][n] = jnp.ones(expand_shape) * numpyro.sample(f"{m}_{n}", sample)
+                # parameters[m][n] = numpyro.sample(f"{m}_{n}", sample.expand(expand_shape)) build a free parameter for each obs
             case float() | ArrayLike():
                 parameters[m][n] = jnp.ones(expand_shape) * sample
             case _:
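Note on the `build_prior` change above: the `Distribution` branch now tiles a single draw across all observations, so each parameter remains one shared free parameter rather than one per observation (the commented-out `expand` variant would do the opposite). A minimal sketch of this broadcast outside jaxspec; the model and site names are made up for illustration:

```python
# Minimal sketch of the expand_shape broadcast; "powerlaw_alpha" is an invented name.
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax.random import PRNGKey
from numpyro.infer import Predictive

def toy_prior(n_obs=3):
    # One scalar draw, tiled so every observation sees the same value
    alpha = jnp.ones((n_obs,)) * numpyro.sample("powerlaw_alpha", dist.Uniform(0.0, 5.0))
    numpyro.deterministic("alpha_per_obs", alpha)

draws = Predictive(toy_prior, num_samples=2)(PRNGKey(0))
print(draws["powerlaw_alpha"].shape)  # (2,): a single shared free parameter
print(draws["alpha_per_obs"].shape)   # (2, 3): the same draw, tiled over 3 observations
```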
@@ -54,7 +58,7 @@ def build_numpyro_model(

     if (getattr(obs, "folded_background", None) is not None) and (background_model is not None):
         bkg_countrate = background_model.numpyro_model(
-            obs.out_energies, obs.folded_background.data, name=
+            obs.out_energies, obs.folded_background.data, name="bkg_" + name, observed=observed
         )
     elif (getattr(obs, "folded_background", None) is None) and (background_model is not None):
         raise ValueError("Trying to fit a background model but no background is linked to this observation")
@@ -66,9 +70,9 @@ def build_numpyro_model(
         countrate = obs_model(prior_params)

         # This is the case where we fit a model to a TOTAL spectrum as defined in OGIP standard
-        with numpyro.plate(
+        with numpyro.plate("obs_plate_" + name, len(obs.folded_counts)):
             numpyro.sample(
-
+                "obs_" + name,
                 Poisson(countrate + bkg_countrate / obs.folded_backratio.data),
                 obs=obs.folded_counts.data if observed else None,
             )
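The likelihood now carries a per-observation name and an explicit plate over spectral channels: total counts are Poisson-distributed around the source count rate plus the folded background rate rescaled by the backratio. A stripped-down, runnable version of this plated Poisson likelihood (all names and values are illustrative, not jaxspec code):

```python
# Toy version of the plated Poisson likelihood, with a single rate parameter.
import jax.numpy as jnp
import numpyro
from jax.random import PRNGKey
from numpyro.distributions import Poisson, Uniform
from numpyro.infer import MCMC, NUTS

counts = jnp.array([3.0, 7.0, 4.0, 5.0])  # stand-in for obs.folded_counts.data

def model(observed=True):
    rate = numpyro.sample("rate", Uniform(0.0, 20.0))
    with numpyro.plate("obs_plate", len(counts)):
        # obs=None turns the site into a predictive draw, which is how
        # Predictive(..., observed=False) is used elsewhere in this file
        numpyro.sample("obs", Poisson(rate), obs=counts if observed else None)

mcmc = MCMC(NUTS(model), num_warmup=200, num_samples=200)
mcmc.run(PRNGKey(0))
print(mcmc.get_samples()["rate"].mean())  # posterior mean near the mean count (~4.75)
```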
@@ -76,6 +80,25 @@ def build_numpyro_model(
     return numpro_model


+def filter_inference_data(inference_data, observation_container, background_model=None) -> az.InferenceData:
+    predictive_parameters = []
+
+    for key, value in observation_container.items():
+        if background_model is not None:
+            predictive_parameters.append(f"obs_{key}")
+            predictive_parameters.append(f"bkg_{key}")
+        else:
+            predictive_parameters.append(f"obs_{key}")
+
+    inference_data.posterior_predictive = inference_data.posterior_predictive[predictive_parameters]
+
+    parameters = [x for x in inference_data.posterior.keys() if not x.endswith("_base")]
+    inference_data.posterior = inference_data.posterior[parameters]
+    inference_data.prior = inference_data.prior[parameters]
+
+    return inference_data
+
+
 class CountForwardModel(hk.Module):
     """
     A haiku module which allows to build the function that simulates the measured counts
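`filter_inference_data` centralizes pruning that was previously inlined in the MCMC fit (see the removed lines further down): it keeps only the per-observation predictive variables and drops the auxiliary `*_base` draws introduced by `TransformReparam`. The same pruning on a hand-built `InferenceData`, with toy arrays:

```python
# Toy demonstration of the pruning performed by filter_inference_data.
import numpy as np
import arviz as az

idata = az.from_dict(
    posterior={"norm": np.ones((1, 50)), "norm_base": np.zeros((1, 50))},
    prior={"norm": np.ones((1, 50)), "norm_base": np.zeros((1, 50))},
    posterior_predictive={"obs_data": np.ones((1, 50, 8)), "unrelated": np.ones((1, 50, 8))},
)

# Keep only the per-observation predictive variables...
idata.posterior_predictive = idata.posterior_predictive[["obs_data"]]
# ...and drop the "_base" variables added by TransformReparam
keep = [k for k in idata.posterior.keys() if not k.endswith("_base")]
idata.posterior = idata.posterior[keep]
idata.prior = idata.prior[keep]

print(list(idata.posterior.data_vars))  # ['norm']
```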
@@ -97,27 +120,87 @@ class CountForwardModel(hk.Module):
         Compute the count functions for a given observation.
         """

-        expected_counts = self.transfer_matrix @ self.model(parameters, *self.energies)
+        expected_counts = self.transfer_matrix @ self.model.photon_flux(parameters, *self.energies)

         return jnp.clip(expected_counts, a_min=1e-6)


-class BayesianModelAbstract(ABC):
+class ModelFitter(ABC):
     """
     Abstract class to fit a model to a given set of observation.
     """

-
-
-
-
-
-
-
+    def __init__(
+        self,
+        model: SpectralModel,
+        observations: ObsConfiguration | list[ObsConfiguration] | dict[str, ObsConfiguration],
+        background_model: BackgroundModel = None,
+        sparsify_matrix: bool = False,
+    ):
+        """
+        Initialize the fitter.

-
+        Parameters:
+            model: the spectral model to fit.
+            observations: the observations to fit the model to.
+            background_model: the background model to fit.
+            sparsify_matrix: whether to sparsify the transfer matrix.
+        """
         self.model = model
+        self._observations = observations
+        self.background_model = background_model
         self.pars = tree_map(lambda x: jnp.float64(x), self.model.params)
+        self.sparse = sparsify_matrix
+
+    @property
+    def _observation_container(self) -> dict[str, ObsConfiguration]:
+        """
+        The observations used in the fit as a dictionary of observations.
+        """
+
+        if isinstance(self._observations, dict):
+            return self._observations
+
+        elif isinstance(self._observations, list):
+            return {f"data_{i}": obs for i, obs in enumerate(self._observations)}
+
+        elif isinstance(self._observations, ObsConfiguration):
+            return {"data": self._observations}
+
+        else:
+            raise ValueError(f"Invalid type for observations : {type(self._observations)}")
+
+    def numpyro_model(self, prior_distributions: HaikuDict[Distribution]) -> Callable:
+        """
+        Build the numpyro model using the observed data, the prior distributions and the spectral model.
+
+        Parameters:
+            prior_distributions: a nested dictionary containing the prior distributions for the model parameters.
+
+        Returns:
+            A model function that can be used with numpyro.
+        """
+
+        def model(observed=True):
+            prior_params = build_prior(prior_distributions, expand_shape=(len(self._observation_container),))
+
+            for i, (key, observation) in enumerate(self._observation_container.items()):
+                params = tree_map(lambda x: x[i], prior_params)
+
+                obs_model = build_numpyro_model(observation, self.model, self.background_model, name=key, sparse=self.sparse)
+                obs_model(params, observed=observed)
+
+        return model
+
+    @abstractmethod
+    def fit(self, prior_distributions: HaikuDict[Distribution], **kwargs) -> FitResult: ...
+
+
+class BayesianFitter(ModelFitter):
+    """
+    A class to fit a model to a given set of observation using a Bayesian approach. This class uses the NUTS sampler
+    from numpyro to perform the inference on the model parameters.
+    """

     def fit(
         self,
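Both concrete fitters inherit `numpyro_model` from `ModelFitter`: `build_prior` returns a parameter tree whose leaves carry a leading axis of length `n_obs`, and each observation's slice is extracted with `tree_map`. That slicing step in isolation, with toy values (the nested layout mirrors `HaikuDict[module][parameter]`; the names are invented):

```python
# How tree_map slices one observation's parameters out of the stacked tree.
import jax.numpy as jnp
from jax.tree_util import tree_map

# Toy stacked parameters for two observations
prior_params = {"powerlaw_1": {"alpha": jnp.array([1.7, 2.1]), "norm": jnp.array([1e-4, 2e-4])}}

params_obs0 = tree_map(lambda x: x[0], prior_params)  # every leaf indexed at observation 0
print(params_obs0["powerlaw_1"]["alpha"])  # 1.7
```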
@@ -128,11 +211,11 @@ class BayesianModelAbstract(ABC):
         num_samples: int = 1000,
         max_tree_depth: int = 10,
         target_accept_prob: float = 0.8,
-        dense_mass=False,
+        dense_mass: bool = False,
         mcmc_kwargs: dict = {},
-    ) ->
+    ) -> FitResult:
         """
-        Fit the model to the data using NUTS sampler from numpyro.
+        Fit the model to the data using NUTS sampler from numpyro.

         Parameters:
             prior_distributions: a nested dictionary containing the prior distributions for the model parameters.
@@ -146,7 +229,7 @@ class BayesianModelAbstract(ABC):
             mcmc_kwargs: additional arguments to pass to the MCMC sampler. See [`MCMC`][numpyro.infer.mcmc.MCMC] for more details.

         Returns:
-            A [`
+            A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
         """

         transform_dict = {}
@@ -174,80 +257,91 @@ class BayesianModelAbstract(ABC):
         prior = Predictive(bayesian_model, num_samples=num_samples)(keys[2], observed=False)
         inference_data = az.from_numpyro(mcmc, prior=prior, posterior_predictive=posterior_predictive)

-
-        inference_data.posterior_predictive = inference_data.posterior_predictive[predictive_parameters]
-
-        parameters = [x for x in inference_data.posterior.keys() if not x.endswith("_base")]
-        inference_data.posterior = inference_data.posterior[parameters]
-        inference_data.prior = inference_data.prior[parameters]
+        inference_data = filter_inference_data(inference_data, self._observation_container, self.background_model)

-        return
+        return FitResult(
             self.model,
-            self.
+            self._observation_container,
             inference_data,
-            mcmc.get_samples(),
             self.model.params,
             background_model=self.background_model,
         )


-class
+class MinimizationFitter(ModelFitter):
     """
-
+    A class to fit a model to a given set of observation using a minimization algorithm. This class uses the L-BFGS
+    algorithm from jaxopt to perform the minimization on the model parameters. The uncertainties are computed using the
+    Hessian of the log-likelihood, assuming that it is a multivariate Gaussian in the unbounded space defined by
+    numpyro.
     """

-    def
-
-
-
-
-
-
-    def numpyro_model(self, prior_distributions: HaikuDict[Distribution]) -> Callable:
+    def fit(
+        self,
+        prior_distributions: HaikuDict[Distribution],
+        rng_key: int = 0,
+        num_iter_max: int = 10_000,
+        num_samples: int = 1_000,
+    ) -> FitResult:
         """
-
-        to fit the model using numpyro's various samplers.
+        Fit the model to the data using L-BFGS algorithm.

         Parameters:
             prior_distributions: a nested dictionary containing the prior distributions for the model parameters.
-
-
-
-        prior_params = build_prior(prior_distributions)
-        obs_model = build_numpyro_model(self.observation, self.model, self.background_model, sparse=self.sparse)
-        obs_model(prior_params, observed=observed)
-
-        return model
+            rng_key: the random key used to initialize the sampler.
+            num_iter_max: the maximum number of iteration in the minimization algorithm.
+            num_samples: the number of sample to draw from the best-fit covariance.

+        Returns:
+            A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
+        """

-
-class MultipleObservationMCMC(BayesianModelAbstract):
+        bayesian_model = self.numpyro_model(prior_distributions)

-
+        param_info, potential_fn, postprocess_fn, *_ = initialize_model(
+            PRNGKey(0),
+            bayesian_model,
+            model_args=tuple(),
+            dynamic_args=True,  # <- this is important!
+        )

-
-
-
-
+        # get negative log-density from the potential function
+        @jax.jit
+        def nll_fn(position):
+            func = potential_fn()
+            return func(position)

-
+        solver = jaxopt.LBFGS(fun=nll_fn, maxiter=10_000)
+        params, state = solver.run(param_info.z)
+        keys = random.split(random.PRNGKey(rng_key), 3)

-
+        value_flat, unflatten_fun = ravel_pytree(params)
+        covariance = jnp.linalg.inv(jax.hessian(lambda p: nll_fn(unflatten_fun(p)))(value_flat))

-
+        samples_flat = jax.random.multivariate_normal(keys[0], value_flat, covariance, shape=(num_samples,))
+        samples = jax.vmap(unflatten_fun)(samples_flat.block_until_ready())
+        posterior_samples = postprocess_fn()(samples)

-
+        posterior_predictive = Predictive(bayesian_model, posterior_samples)(keys[1], observed=False)
+        prior = Predictive(bayesian_model, num_samples=num_samples)(keys[2], observed=False)
+        log_likelihood = numpyro.infer.log_likelihood(bayesian_model, posterior_samples)

-
+        def sanitize_chain(chain):
+            return tree_map(lambda x: x[None, ...], chain)

-
-
-
-
-
-
+        inference_data = az.from_dict(
+            sanitize_chain(posterior_samples),
+            prior=sanitize_chain(prior),
+            posterior_predictive=sanitize_chain(posterior_predictive),
+            log_likelihood=sanitize_chain(log_likelihood),
+        )

-
+        inference_data = filter_inference_data(inference_data, self._observation_container, self.background_model)

-        return
-
+        return FitResult(
+            self.model,
+            self._observation_container,
+            inference_data,
+            self.model.params,
+            background_model=self.background_model,
+        )
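`MinimizationFitter.fit` is a MAP-plus-Laplace recipe: minimize numpyro's potential energy (the negative log joint density) in the unconstrained space with L-BFGS, take the inverse Hessian at the optimum as a Gaussian covariance, draw samples from it, and map them back to the constrained space with `postprocess_fn`. (Note that `num_iter_max` is accepted but `maxiter` is hardcoded to `10_000` in the solver call above.) A self-contained toy run of the same recipe, on a made-up conjugate-normal model rather than a jaxspec spectral model:

```python
# MAP + Laplace approximation on a toy numpyro model (not jaxspec code).
import jax
import jax.numpy as jnp
import jaxopt
import numpyro
import numpyro.distributions as dist
from jax.flatten_util import ravel_pytree
from jax.random import PRNGKey
from numpyro.infer.util import initialize_model

def toy_model():
    mu = numpyro.sample("mu", dist.Normal(0.0, 10.0))
    numpyro.sample("y", dist.Normal(mu, 1.0), obs=jnp.array([0.9, 1.1, 1.3]))

param_info, potential_fn, postprocess_fn, _ = initialize_model(
    PRNGKey(0), toy_model, model_args=(), dynamic_args=True
)

# Negative log-density in numpyro's unconstrained space
nll_fn = jax.jit(lambda position: potential_fn()(position))

# MAP estimate via L-BFGS, starting from numpyro's init position
params, state = jaxopt.LBFGS(fun=nll_fn, maxiter=1_000).run(param_info.z)

# Laplace approximation: covariance = inverse Hessian at the optimum
value_flat, unflatten = ravel_pytree(params)
covariance = jnp.linalg.inv(jax.hessian(lambda p: nll_fn(unflatten(p)))(value_flat))

samples_flat = jax.random.multivariate_normal(PRNGKey(1), value_flat, covariance, shape=(1_000,))
posterior = postprocess_fn()(jax.vmap(unflatten)(samples_flat))
print(posterior["mu"].mean())  # close to the conjugate posterior mean, ~1.1
```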