jaxspec 0.0.6__py3-none-any.whl → 0.0.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxspec/analysis/results.py +250 -121
- jaxspec/data/__init__.py +4 -4
- jaxspec/data/obsconf.py +53 -8
- jaxspec/data/util.py +29 -20
- jaxspec/fit.py +329 -81
- jaxspec/model/__init__.py +0 -1
- jaxspec/model/_additive/apec.py +56 -117
- jaxspec/model/_additive/apec_loaders.py +42 -59
- jaxspec/model/additive.py +27 -13
- jaxspec/model/background.py +50 -16
- jaxspec/model/multiplicative.py +20 -25
- jaxspec/util/__init__.py +45 -0
- jaxspec/util/abundance.py +5 -3
- jaxspec/util/online_storage.py +15 -0
- jaxspec/util/typing.py +43 -0
- {jaxspec-0.0.6.dist-info → jaxspec-0.0.7.dist-info}/METADATA +11 -8
- {jaxspec-0.0.6.dist-info → jaxspec-0.0.7.dist-info}/RECORD +19 -21
- jaxspec/tables/abundances.dat +0 -31
- jaxspec/tables/xsect_phabs_aspl.fits +0 -0
- jaxspec/tables/xsect_tbabs_wilm.fits +0 -0
- jaxspec/tables/xsect_wabs_angr.fits +0 -0
- {jaxspec-0.0.6.dist-info → jaxspec-0.0.7.dist-info}/LICENSE.md +0 -0
- {jaxspec-0.0.6.dist-info → jaxspec-0.0.7.dist-info}/WHEEL +0 -0
jaxspec/fit.py
CHANGED
@@ -1,67 +1,84 @@
+from abc import ABC, abstractmethod
+from collections.abc import Callable
+from typing import Literal
+
+import arviz as az
 import haiku as hk
+import jax
 import jax.numpy as jnp
 import numpyro
-import
-
-from typing import Callable, TypeVar
-from abc import ABC, abstractmethod
+import optimistix as optx
+
 from jax import random
-from jax.
+from jax.experimental.sparse import BCOO
 from jax.flatten_util import ravel_pytree
-from jax.
-from .
-from .model.abc import SpectralModel
-from .data import ObsConfiguration
-from .model.background import BackgroundModel
-from numpyro.infer import MCMC, NUTS, Predictive
-from numpyro.distributions import Distribution, TransformedDistribution
-from numpyro.distributions import Poisson
+from jax.random import PRNGKey
+from jax.tree_util import tree_map
 from jax.typing import ArrayLike
+from numpyro.contrib.nested_sampling import NestedSampler
+from numpyro.distributions import Distribution, Poisson, TransformedDistribution
+from numpyro.infer import MCMC, NUTS, Predictive
+from numpyro.infer.initialization import init_to_value
 from numpyro.infer.reparam import TransformReparam
-from numpyro.infer.util import
-from
-import jaxopt
+from numpyro.infer.util import constrain_fn
+from scipy.stats import Covariance, multivariate_normal
 
-
-
-
-
-
+from .analysis.results import FitResult
+from .data import ObsConfiguration
+from .model.abc import SpectralModel
+from .model.background import BackgroundModel
+from .util import catchtime, sample_prior
+from .util.typing import PriorDictModel, PriorDictType
 
 
-def build_prior(prior:
+def build_prior(prior: PriorDictType, expand_shape: tuple = (), prefix=""):
+    """
+    Transform a dictionary of prior distributions into a dictionary of parameters sampled from the prior.
+    Must be used within a numpyro model.
+    """
     parameters = dict(hk.data_structures.to_haiku_dict(prior))
 
     for i, (m, n, sample) in enumerate(hk.data_structures.traverse(prior)):
         match sample:
             case Distribution():
-                parameters[m][n] = jnp.ones(expand_shape) * numpyro.sample(
-
-
+                parameters[m][n] = jnp.ones(expand_shape) * numpyro.sample(
+                    f"{prefix}{m}_{n}", sample
+                )
+            case ArrayLike():
                 parameters[m][n] = jnp.ones(expand_shape) * sample
             case _:
-                raise ValueError(
+                raise ValueError(
+                    f"Invalid prior type {type(sample)} for parameter {prefix}{m}_{n} : {sample}"
+                )
 
     return parameters
 
 
-def
+def build_numpyro_model_for_single_obs(
     obs: ObsConfiguration,
     model: SpectralModel,
     background_model: BackgroundModel,
     name: str = "",
     sparse: bool = False,
 ) -> Callable:
-
+    """
+    Build a numpyro model for a given observation and spectral model.
+    """
+
+    def numpyro_model(prior_params, observed=True):
         # prior_params = build_prior(prior_distributions, name=name)
-        transformed_model = hk.without_apply_rng(
+        transformed_model = hk.without_apply_rng(
+            hk.transform(lambda par: CountForwardModel(model, obs, sparse=sparse)(par))
+        )
 
         if (getattr(obs, "folded_background", None) is not None) and (background_model is not None):
             bkg_countrate = background_model.numpyro_model(
-                obs
+                obs, model, name="bkg_" + name, observed=observed
            )
         elif (getattr(obs, "folded_background", None) is None) and (background_model is not None):
-            raise ValueError(
+            raise ValueError(
+                "Trying to fit a background model but no background is linked to this observation"
+            )
 
         else:
            bkg_countrate = 0.0
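For context, `build_prior` consumes a nested `{module: {parameter: prior}}` mapping and samples one numpyro site per leaf. The sketch below illustrates that layout; the module and parameter names (`powerlaw_1`, `tbabs_1`, …) are hypothetical placeholders, not values mandated by the package:

```python
import haiku as hk
import numpyro.distributions as dist

# A Distribution leaf is sampled at the site f"{prefix}{module}_{parameter}";
# a bare array-like leaf is broadcast as a frozen value instead.
prior = {
    "powerlaw_1": {
        "alpha": dist.Uniform(0.0, 5.0),
        "norm": dist.LogUniform(1e-6, 1e-2),
    },
    "tbabs_1": {"N_H": 0.2},  # fixed value, not sampled
}

# hk.data_structures.traverse yields (module, parameter, value) triplets,
# which is how build_prior derives its sample-site names.
for module, name, value in hk.data_structures.traverse(prior):
    print(f"{module}_{name} -> {value}")
```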
@@ -73,14 +90,16 @@ def build_numpyro_model(
         with numpyro.plate("obs_plate_" + name, len(obs.folded_counts)):
             numpyro.sample(
                 "obs_" + name,
-                Poisson(countrate + bkg_countrate
+                Poisson(countrate + bkg_countrate * obs.folded_backratio.data),
                 obs=obs.folded_counts.data if observed else None,
             )
 
-    return
+    return numpyro_model
 
 
-def filter_inference_data(
+def filter_inference_data(
+    inference_data, observation_container, background_model=None
+) -> az.InferenceData:
     predictive_parameters = []
 
     for key, value in observation_container.items():
@@ -109,8 +128,12 @@ class CountForwardModel(hk.Module):
         self.model = model
         self.energies = jnp.asarray(folding.in_energies)
 
-        if
-
+        if (
+            sparse
+        ):  # folding.transfer_matrix.data.density > 0.015 is a good criterion to consider sparsify
+            self.transfer_matrix = BCOO.from_scipy_sparse(
+                folding.transfer_matrix.data.to_scipy_sparse().tocsr()
+            )
 
         else:
             self.transfer_matrix = jnp.asarray(folding.transfer_matrix.data.todense())
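The sparse branch converts the folded transfer matrix into a JAX `BCOO` array so the forward fold stays jit-compatible while skipping zero entries. A toy illustration of that conversion (the matrix here is random filler, not the jaxspec response):

```python
import jax.numpy as jnp
import scipy.sparse
from jax.experimental.sparse import BCOO

# Stand-in for a mostly-empty instrument response in CSR form
csr = scipy.sparse.random(32, 64, density=0.01, format="csr")
transfer_matrix = BCOO.from_scipy_sparse(csr)

flux = jnp.ones(64)
counts = transfer_matrix @ flux  # sparse mat-vec, usable under jit/vmap
print(counts.shape)  # (32,)
```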
@@ -140,7 +163,8 @@ class ModelFitter(ABC):
         """
         Initialize the fitter.
 
-        Parameters
+        Parameters
+        ----------
         model: the spectral model to fit.
         observations: the observations to fit the model to.
         background_model: the background model to fit.
@@ -170,30 +194,61 @@ class ModelFitter(ABC):
         else:
             raise ValueError(f"Invalid type for observations : {type(self._observations)}")
 
-    def numpyro_model(self, prior_distributions:
+    def numpyro_model(self, prior_distributions: PriorDictType | Callable) -> Callable:
         """
         Build the numpyro model using the observed data, the prior distributions and the spectral model.
 
-        Parameters
+        Parameters
+        ----------
         prior_distributions: a nested dictionary containing the prior distributions for the model parameters.
 
         Returns:
+        -------
         A model function that can be used with numpyro.
         """
 
+        if not callable(prior_distributions):
+            # Validate the entry with pydantic
+            prior_distributions = PriorDictModel(nested_dict=prior_distributions).nested_dict
+
+            def prior_distributions_func():
+                return build_prior(
+                    prior_distributions, expand_shape=(len(self._observation_container),)
+                )
+
+        else:
+            prior_distributions_func = prior_distributions
+
         def model(observed=True):
-            prior_params =
+            prior_params = prior_distributions_func()
 
+            # Iterate over all the observations in our container and build a single numpyro model for each observation
             for i, (key, observation) in enumerate(self._observation_container.items()):
+                # We expect that prior_params contains an array of parameters for each observation
+                # They can be identical or different for each observation
                 params = tree_map(lambda x: x[i], prior_params)
 
-                obs_model =
+                obs_model = build_numpyro_model_for_single_obs(
+                    observation, self.model, self.background_model, name=key, sparse=self.sparse
+                )
+
                 obs_model(params, observed=observed)
 
         return model
 
+    def transformed_numpyro_model(self, prior_distributions: PriorDictType) -> Callable:
+        transform_dict = {}
+
+        for m, n, val in hk.data_structures.traverse(prior_distributions):
+            if isinstance(val, TransformedDistribution):
+                transform_dict[f"{m}_{n}"] = TransformReparam()
+
+        return numpyro.handlers.reparam(
+            self.numpyro_model(prior_distributions), config=transform_dict
+        )
+
     @abstractmethod
-    def fit(self, prior_distributions:
+    def fit(self, prior_distributions: PriorDictType, **kwargs) -> FitResult: ...
 
 
 class BayesianFitter(ModelFitter):
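`numpyro_model` now also accepts a callable in place of the nested dict, which makes correlated or shared parameters possible. A sketch, assuming the callable must return the same `{module: {parameter: value}}` structure as `build_prior`, with a leading axis over observations so that `tree_map(lambda x: x[i], ...)` can slice per-observation parameters (site and module names below are hypothetical):

```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def shared_prior():
    # One photon index shared by two observations, per-observation norms
    alpha = numpyro.sample("powerlaw_1_alpha", dist.Uniform(0.0, 5.0))
    norm = numpyro.sample("powerlaw_1_norm", dist.LogUniform(1e-6, 1e-2).expand([2]))
    return {"powerlaw_1": {"alpha": jnp.full((2,), alpha), "norm": norm}}

# Quick smoke test outside of inference
with numpyro.handlers.seed(rng_seed=0):
    print(shared_prior())
```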
@@ -204,20 +259,22 @@ class BayesianFitter(ModelFitter):
 
     def fit(
         self,
-        prior_distributions:
+        prior_distributions: PriorDictType,
         rng_key: int = 0,
-        num_chains: int =
+        num_chains: int = len(jax.devices()),
         num_warmup: int = 1000,
         num_samples: int = 1000,
         max_tree_depth: int = 10,
         target_accept_prob: float = 0.8,
         dense_mass: bool = False,
+        kernel_kwargs: dict = {},
         mcmc_kwargs: dict = {},
     ) -> FitResult:
         """
         Fit the model to the data using NUTS sampler from numpyro.
 
-        Parameters
+        Parameters
+        ----------
         prior_distributions: a nested dictionary containing the prior distributions for the model parameters.
         rng_key: the random key used to initialize the sampler.
         num_chains: the number of chains to run.
@@ -229,14 +286,12 @@ class BayesianFitter(ModelFitter):
         mcmc_kwargs: additional arguments to pass to the MCMC sampler. See [`MCMC`][numpyro.infer.mcmc.MCMC] for more details.
 
         Returns:
+        -------
         A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
         """
 
-
-
-        for m, n, val in hk.data_structures.traverse(prior_distributions):
-            if isinstance(val, TransformedDistribution):
-                transform_dict[f"{m}_{n}"] = TransformReparam()
+        bayesian_model = self.transformed_numpyro_model(prior_distributions)
+        # bayesian_model = self.numpyro_model(prior_distributions)
 
         chain_kwargs = {
             "num_warmup": num_warmup,
@@ -244,20 +299,29 @@ class BayesianFitter(ModelFitter):
             "num_chains": num_chains,
         }
 
-
-
-
+        kernel = NUTS(
+            bayesian_model,
+            max_tree_depth=max_tree_depth,
+            target_accept_prob=target_accept_prob,
+            dense_mass=dense_mass,
+            **kernel_kwargs,
+        )
 
         mcmc = MCMC(kernel, **(chain_kwargs | mcmc_kwargs))
-
         keys = random.split(random.PRNGKey(rng_key), 3)
 
         mcmc.run(keys[0])
-        posterior_predictive = Predictive(bayesian_model, mcmc.get_samples())(
+        posterior_predictive = Predictive(bayesian_model, mcmc.get_samples())(
+            keys[1], observed=False
+        )
         prior = Predictive(bayesian_model, num_samples=num_samples)(keys[2], observed=False)
-        inference_data = az.from_numpyro(
+        inference_data = az.from_numpyro(
+            mcmc, prior=prior, posterior_predictive=posterior_predictive
+        )
 
-        inference_data = filter_inference_data(
+        inference_data = filter_inference_data(
+            inference_data, self._observation_container, self.background_model
+        )
 
         return FitResult(
             self.model,
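Putting the NUTS path together, a hypothetical end-to-end call might look like the following; the PHA file name and the exact loader and model class names are assumptions for illustration, not fixtures shipped with the package:

```python
import numpyro.distributions as dist
from jaxspec.data import ObsConfiguration
from jaxspec.fit import BayesianFitter
from jaxspec.model.additive import Powerlaw
from jaxspec.model.multiplicative import Tbabs

obs = ObsConfiguration.from_pha_file("observation.pha")  # placeholder path
model = Tbabs() * Powerlaw()

prior = {
    "powerlaw_1": {"alpha": dist.Uniform(0, 5), "norm": dist.LogUniform(1e-6, 1e-2)},
    "tbabs_1": {"N_H": dist.Uniform(0, 1)},
}

result = BayesianFitter(model, obs).fit(
    prior,
    num_chains=4,
    num_warmup=1000,
    num_samples=1000,
    mcmc_kwargs={"progress_bar": True},
)
```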
@@ -272,61 +336,242 @@ class MinimizationFitter(ModelFitter):
     """
     A class to fit a model to a given set of observation using a minimization algorithm. This class uses the L-BFGS
     algorithm from jaxopt to perform the minimization on the model parameters. The uncertainties are computed using the
-    Hessian of the log-
+    Hessian of the log-log_likelihood, assuming that it is a multivariate Gaussian in the unbounded space defined by
     numpyro.
     """
 
     def fit(
         self,
-        prior_distributions:
+        prior_distributions: PriorDictType,
         rng_key: int = 0,
-        num_iter_max: int =
+        num_iter_max: int = 100_000,
         num_samples: int = 1_000,
+        solver: Literal["bfgs", "levenberg_marquardt"] = "bfgs",
+        init_params=None,
+        refine_first_guess=True,
     ) -> FitResult:
         """
         Fit the model to the data using L-BFGS algorithm.
 
-        Parameters
+        Parameters
+        ----------
         prior_distributions: a nested dictionary containing the prior distributions for the model parameters.
         rng_key: the random key used to initialize the sampler.
         num_iter_max: the maximum number of iteration in the minimization algorithm.
         num_samples: the number of sample to draw from the best-fit covariance.
 
         Returns:
+        -------
         A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
         """
 
         bayesian_model = self.numpyro_model(prior_distributions)
+        keys = jax.random.split(PRNGKey(rng_key), 4)
 
-
-
-
-
-
-
+        if init_params is not None:
+            # We initialize the parameters by randomly sampling from the prior
+            local_keys = jax.random.split(keys[0], 2)
+            starting_value = sample_prior(
+                prior_distributions, key=local_keys[0], flat_parameters=True
+            )
+
+            # We update the starting value with the provided init_params
+            for m, n, val in hk.data_structures.traverse(init_params):
+                if f"{m}_{n}" in starting_value.keys():
+                    starting_value[f"{m}_{n}"] = val
+
+            init_params, _ = numpyro.infer.util.find_valid_initial_params(
+                local_keys[1], bayesian_model, init_strategy=init_to_value(values=starting_value)
+            )
+
+        else:
+            init_params, _ = numpyro.infer.util.find_valid_initial_params(keys[0], bayesian_model)
+
+        init_params = init_params[0]
 
-        # get negative log-density from the potential function
         @jax.jit
-        def
-
-
+        def nll(unconstrained_params, _):
+            constrained_params = constrain_fn(
+                bayesian_model, tuple(), dict(observed=True), unconstrained_params
+            )
 
-
-
-
+            log_likelihood = numpyro.infer.util.log_likelihood(
+                model=bayesian_model, posterior_samples=constrained_params
+            )
+
+            # We solve a least square problem, this function ensure that the total residual is indeed the nll
+            return jax.tree.map(lambda x: jnp.sqrt(-x), log_likelihood)
+
+        """
+        if refine_first_guess:
+            with catchtime("Refine_first"):
+                solution = optx.least_squares(
+                    nll,
+                    optx.BestSoFarMinimiser(optx.OptaxMinimiser(optax.adam(1e-4), 1e-6, 1e-6)),
+                    init_params,
+                    max_steps=1000,
+                    throw=False
+                )
+            init_params = solution.value
+        """
 
+        if solver == "bfgs":
+            solver = optx.BestSoFarMinimiser(optx.BFGS(1e-6, 1e-6))
+        elif solver == "levenberg_marquardt":
+            solver = optx.BestSoFarLeastSquares(optx.LevenbergMarquardt(1e-6, 1e-6))
+        else:
+            raise NotImplementedError(f"{solver} is not implemented")
+
+        with catchtime("Minimization"):
+            solution = optx.least_squares(
+                nll,
+                solver,
+                init_params,
+                max_steps=num_iter_max,
+            )
+
+        params = solution.value
         value_flat, unflatten_fun = ravel_pytree(params)
-        covariance = jnp.linalg.inv(jax.hessian(lambda p: nll_fn(unflatten_fun(p)))(value_flat))
 
-
-
-
+        with catchtime("Compute error"):
+            precision = jax.hessian(
+                lambda p: jnp.sum(ravel_pytree(nll(unflatten_fun(p), None))[0] ** 2)
+            )(value_flat)
 
-
-
+            cov = Covariance.from_precision(precision)
+
+        samples_flat = multivariate_normal.rvs(mean=value_flat, cov=cov, size=num_samples)
+
+        samples = jax.vmap(unflatten_fun)(samples_flat)
+        posterior_samples = jax.jit(
+            jax.vmap(lambda p: constrain_fn(bayesian_model, tuple(), dict(observed=True), p))
+        )(samples)
+
+        with catchtime("Posterior"):
+            posterior_predictive = Predictive(bayesian_model, posterior_samples)(
+                keys[2], observed=False
+            )
+            prior = Predictive(bayesian_model, num_samples=num_samples)(keys[3], observed=False)
+            log_likelihood = numpyro.infer.log_likelihood(bayesian_model, posterior_samples)
+
+        def sanitize_chain(chain):
+            """
+            reshape the samples so that it is arviz compliant with an extra starting dimension
+            """
+            return tree_map(lambda x: x[None, ...], chain)
+
+        # We export the observed values to the inference_data
+        seeded_model = numpyro.handlers.substitute(
+            numpyro.handlers.seed(bayesian_model, jax.random.PRNGKey(0)),
+            substitute_fn=numpyro.infer.init_to_sample,
+        )
+        trace = numpyro.handlers.trace(seeded_model).get_trace()
+        observations = {
+            name: site["value"]
+            for name, site in trace.items()
+            if site["type"] == "sample" and site["is_observed"]
+        }
+
+        with catchtime("InferenceData wrapping"):
+            inference_data = az.from_dict(
+                sanitize_chain(posterior_samples),
+                prior=sanitize_chain(prior),
+                posterior_predictive=sanitize_chain(posterior_predictive),
+                log_likelihood=sanitize_chain(log_likelihood),
+                observed_data=observations,
+            )
+
+        inference_data = filter_inference_data(
+            inference_data, self._observation_container, self.background_model
+        )
+
+        return FitResult(
+            self.model,
+            self._observation_container,
+            inference_data,
+            self.model.params,
+            background_model=self.background_model,
+        )
+
+
+class NestedSamplingFitter(ModelFitter):
+    r"""
+    A class to fit a model to a given set of observation using the Nested Sampling algorithm. This class uses the
+    [`DefaultNestedSampler`][jaxns.DefaultNestedSampler] from [`jaxns`](https://jaxns.readthedocs.io/en/latest/) which
+    implements the [Phantom-Powered Nested Sampling](https://arxiv.org/abs/2312.11330) algorithm.
+    Add Citation to jaxns
+    """
+
+    def fit(
+        self,
+        prior_distributions: PriorDictType,
+        rng_key: int = 0,
+        num_parallel_workers: int = len(jax.devices()),
+        num_samples: int = 1000,
+        plot_diagnostics=False,
+        verbose=True,
+    ) -> FitResult:
+        """
+        Fit the model to the data using the Phantom-Powered nested sampling algorithm.
+
+        Parameters:
+            prior_distributions: a nested dictionary containing the prior distributions for the model parameters.
+            rng_key: the random key used to initialize the sampler.
+            num_samples: the number of samples to draw.
+
+        Returns:
+            A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
+        """
+
+        bayesian_model = self.transformed_numpyro_model(prior_distributions)
+        keys = random.split(random.PRNGKey(rng_key), 4)
+
+        ns = NestedSampler(
+            bayesian_model,
+            constructor_kwargs=dict(
+                num_parallel_workers=num_parallel_workers,
+                verbose=verbose,
+                difficult_model=True,
+                # max_samples=1e6,
+                # num_live_points=10_000,
+                # init_efficiency_threshold=0.5,
+                parameter_estimation=True,
+            ),
+            termination_kwargs=dict(dlogZ=1e-2),
+        )
+
+        ns.run(keys[0])
+
+        self.ns = ns
+
+        if plot_diagnostics:
+            ns.diagnostics()
+
+        posterior_samples = ns.get_samples(keys[1], num_samples=num_samples * num_parallel_workers)
         log_likelihood = numpyro.infer.log_likelihood(bayesian_model, posterior_samples)
+        posterior_predictive = Predictive(bayesian_model, posterior_samples)(
+            keys[2], observed=False
+        )
+
+        prior = Predictive(bayesian_model, num_samples=num_samples * num_parallel_workers)(
+            keys[3], observed=False
+        )
+
+        seeded_model = numpyro.handlers.substitute(
+            numpyro.handlers.seed(bayesian_model, jax.random.PRNGKey(0)),
+            substitute_fn=numpyro.infer.init_to_sample,
+        )
+        trace = numpyro.handlers.trace(seeded_model).get_trace()
+        observations = {
+            name: site["value"]
+            for name, site in trace.items()
+            if site["type"] == "sample" and site["is_observed"]
+        }
 
         def sanitize_chain(chain):
+            """
+            reshape the samples so that it is arviz compliant with an extra starting dimension
+            """
             return tree_map(lambda x: x[None, ...], chain)
 
         inference_data = az.from_dict(
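The two new fitting entry points can be called the same way as the NUTS path. A sketch, reusing the hypothetical `model`, `obs`, and `prior` objects from the earlier example:

```python
from jaxspec.fit import MinimizationFitter, NestedSamplingFitter

# Least-squares path: Gaussian uncertainties from the Hessian at the optimum
result_minim = MinimizationFitter(model, obs).fit(
    prior,
    solver="levenberg_marquardt",  # or the default "bfgs"
    num_samples=1_000,
)

# Nested sampling path: also yields the evidence via jaxns
result_ns = NestedSamplingFitter(model, obs).fit(
    prior,
    num_samples=1_000,
    plot_diagnostics=False,
)
```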
@@ -334,9 +579,12 @@ class MinimizationFitter(ModelFitter):
             prior=sanitize_chain(prior),
             posterior_predictive=sanitize_chain(posterior_predictive),
             log_likelihood=sanitize_chain(log_likelihood),
+            observed_data=observations,
         )
 
-        inference_data = filter_inference_data(
+        inference_data = filter_inference_data(
+            inference_data, self._observation_container, self.background_model
+        )
 
         return FitResult(
             self.model,
jaxspec/model/__init__.py
CHANGED
@@ -1 +0,0 @@
-