jaxspec 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl
This diff shows the changes between publicly released package versions as they appear in their respective registries, and is provided for informational purposes only.
- jaxspec/analysis/results.py +297 -121
- jaxspec/data/__init__.py +4 -4
- jaxspec/data/obsconf.py +53 -8
- jaxspec/data/util.py +114 -84
- jaxspec/fit.py +335 -96
- jaxspec/model/__init__.py +0 -1
- jaxspec/model/_additive/apec.py +56 -117
- jaxspec/model/_additive/apec_loaders.py +42 -59
- jaxspec/model/additive.py +194 -55
- jaxspec/model/background.py +50 -16
- jaxspec/model/multiplicative.py +63 -41
- jaxspec/util/__init__.py +45 -0
- jaxspec/util/abundance.py +5 -3
- jaxspec/util/online_storage.py +28 -0
- jaxspec/util/typing.py +43 -0
- {jaxspec-0.0.6.dist-info → jaxspec-0.0.8.dist-info}/METADATA +14 -10
- {jaxspec-0.0.6.dist-info → jaxspec-0.0.8.dist-info}/RECORD +19 -25
- {jaxspec-0.0.6.dist-info → jaxspec-0.0.8.dist-info}/WHEEL +1 -1
- jaxspec/data/example_data/MOS1.pha +0 -46
- jaxspec/data/example_data/MOS2.pha +0 -42
- jaxspec/data/example_data/PN.pha +1 -293
- jaxspec/data/example_data/fakeit.pha +1 -335
- jaxspec/tables/abundances.dat +0 -31
- jaxspec/tables/xsect_phabs_aspl.fits +0 -0
- jaxspec/tables/xsect_tbabs_wilm.fits +0 -0
- jaxspec/tables/xsect_wabs_angr.fits +0 -0
- {jaxspec-0.0.6.dist-info → jaxspec-0.0.8.dist-info}/LICENSE.md +0 -0
jaxspec/fit.py
CHANGED
@@ -1,67 +1,84 @@
+from abc import ABC, abstractmethod
+from collections.abc import Callable
+from typing import Literal
+
+import arviz as az
 import haiku as hk
+import jax
 import jax.numpy as jnp
 import numpyro
-import
-
-from typing import Callable, TypeVar
-from abc import ABC, abstractmethod
+import optimistix as optx
+
 from jax import random
-from jax.
+from jax.experimental.sparse import BCOO
 from jax.flatten_util import ravel_pytree
-from jax.
-from .
-from .model.abc import SpectralModel
-from .data import ObsConfiguration
-from .model.background import BackgroundModel
-from numpyro.infer import MCMC, NUTS, Predictive
-from numpyro.distributions import Distribution, TransformedDistribution
-from numpyro.distributions import Poisson
+from jax.random import PRNGKey
+from jax.tree_util import tree_map
 from jax.typing import ArrayLike
+from numpyro.contrib.nested_sampling import NestedSampler
+from numpyro.distributions import Distribution, Poisson, TransformedDistribution
+from numpyro.infer import MCMC, NUTS, Predictive
+from numpyro.infer.initialization import init_to_value
+from numpyro.infer.inspect import get_model_relations
 from numpyro.infer.reparam import TransformReparam
-from numpyro.infer.util import
-from
-import jaxopt
+from numpyro.infer.util import constrain_fn
+from scipy.stats import Covariance, multivariate_normal
 
-
-
-
-
-
+from .analysis.results import FitResult
+from .data import ObsConfiguration
+from .model.abc import SpectralModel
+from .model.background import BackgroundModel
+from .util import catchtime
+from .util.typing import PriorDictModel, PriorDictType
 
 
-def build_prior(prior:
+def build_prior(prior: PriorDictType, expand_shape: tuple = (), prefix=""):
+    """
+    Transform a dictionary of prior distributions into a dictionary of parameters sampled from the prior.
+    Must be used within a numpyro model.
+    """
     parameters = dict(hk.data_structures.to_haiku_dict(prior))
 
     for i, (m, n, sample) in enumerate(hk.data_structures.traverse(prior)):
-
-
-
-
-
-
-
-
+        if isinstance(sample, Distribution):
+            parameters[m][n] = jnp.ones(expand_shape) * numpyro.sample(f"{prefix}{m}_{n}", sample)
+
+        elif isinstance(sample, ArrayLike):
+            parameters[m][n] = jnp.ones(expand_shape) * sample
+
+        else:
+            raise ValueError(
+                f"Invalid prior type {type(sample)} for parameter {prefix}{m}_{n} : {sample}"
+            )
 
     return parameters
 
 
-def
+def build_numpyro_model_for_single_obs(
     obs: ObsConfiguration,
     model: SpectralModel,
     background_model: BackgroundModel,
     name: str = "",
     sparse: bool = False,
 ) -> Callable:
-
+    """
+    Build a numpyro model for a given observation and spectral model.
+    """
+
+    def numpyro_model(prior_params, observed=True):
         # prior_params = build_prior(prior_distributions, name=name)
-        transformed_model = hk.without_apply_rng(
+        transformed_model = hk.without_apply_rng(
+            hk.transform(lambda par: CountForwardModel(model, obs, sparse=sparse)(par))
+        )
 
         if (getattr(obs, "folded_background", None) is not None) and (background_model is not None):
             bkg_countrate = background_model.numpyro_model(
-                obs
+                obs, model, name="bkg_" + name, observed=observed
            )
         elif (getattr(obs, "folded_background", None) is None) and (background_model is not None):
-            raise ValueError(
+            raise ValueError(
+                "Trying to fit a background model but no background is linked to this observation"
+            )
 
         else:
             bkg_countrate = 0.0
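For orientation, the reworked `build_prior` draws one value per `(module, parameter)` entry of a nested prior dictionary and broadcasts it to `expand_shape`, so each observation gets its own copy of the parameters. A minimal sketch of how it could be driven, assuming hypothetical module and parameter names (`powerlaw_1`, `alpha`, `norm`) that are not taken from this diff:

```python
import jax
import numpyro
import numpyro.distributions as dist

# Hypothetical nested prior: one module with a sampled and a fixed parameter.
prior = {
    "powerlaw_1": {
        "alpha": dist.Uniform(0.0, 5.0),  # sampled through numpyro.sample
        "norm": 1e-4,                     # plain array-like, broadcast as-is
    }
}

# build_prior must run inside a numpyro trace; a seed handler provides one.
with numpyro.handlers.seed(rng_seed=jax.random.PRNGKey(0)):
    params = build_prior(prior, expand_shape=(2,))  # one value per observation

assert params["powerlaw_1"]["alpha"].shape == (2,)
```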
@@ -77,10 +94,12 @@ def build_numpyro_model(
             obs=obs.folded_counts.data if observed else None,
         )
 
-    return
+    return numpyro_model
 
 
-def filter_inference_data(
+def filter_inference_data(
+    inference_data, observation_container, background_model=None
+) -> az.InferenceData:
     predictive_parameters = []
 
     for key, value in observation_container.items():
@@ -109,8 +128,12 @@ class CountForwardModel(hk.Module):
         self.model = model
         self.energies = jnp.asarray(folding.in_energies)
 
-        if
-
+        if (
+            sparse
+        ):  # folding.transfer_matrix.data.density > 0.015 is a good criterion to consider sparsify
+            self.transfer_matrix = BCOO.from_scipy_sparse(
+                folding.transfer_matrix.data.to_scipy_sparse().tocsr()
+            )
 
         else:
             self.transfer_matrix = jnp.asarray(folding.transfer_matrix.data.todense())
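For context on the sparse branch above: `BCOO` is JAX's batched-COO sparse format, and a `BCOO` matrix supports the same matrix-vector product the dense path uses, so only the storage changes. A self-contained sketch with a stand-in matrix (not an actual instrument response):

```python
import jax.numpy as jnp
import numpy as np
import scipy.sparse as sp
from jax.experimental.sparse import BCOO

# Stand-in for folding.transfer_matrix.data: a mostly-empty response matrix.
dense = np.diag(np.ones(4))
dense[0, 3] = 0.5

transfer_matrix = BCOO.from_scipy_sparse(sp.csr_matrix(dense))

# Sparse mat-vec, analogous to what CountForwardModel applies to a spectrum.
counts = transfer_matrix @ jnp.ones(4)
```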
@@ -125,14 +148,15 @@ class CountForwardModel(hk.Module):
         return jnp.clip(expected_counts, a_min=1e-6)
 
 
-class
+class BayesianModel:
     """
-
+    Class to fit a model to a given set of observation.
     """
 
     def __init__(
         self,
         model: SpectralModel,
+        prior_distributions: PriorDictType | Callable,
         observations: ObsConfiguration | list[ObsConfiguration] | dict[str, ObsConfiguration],
         background_model: BackgroundModel = None,
         sparsify_matrix: bool = False,
@@ -142,6 +166,8 @@ class ModelFitter(ABC):
 
     Parameters:
         model: the spectral model to fit.
+        prior_distributions: a nested dictionary containing the prior distributions for the model parameters, or a
+            callable function that returns parameter samples.
         observations: the observations to fit the model to.
         background_model: the background model to fit.
         sparsify_matrix: whether to sparsify the transfer matrix.
@@ -152,8 +178,20 @@ class ModelFitter(ABC):
         self.pars = tree_map(lambda x: jnp.float64(x), self.model.params)
         self.sparse = sparsify_matrix
 
+        if not callable(prior_distributions):
+            # Validate the entry with pydantic
+            prior = PriorDictModel(nested_dict=prior_distributions).nested_dict
+
+            def prior_distributions_func():
+                return build_prior(prior, expand_shape=(len(self.observation_container),))
+
+        else:
+            prior_distributions_func = prior_distributions
+
+        self.prior_distributions_func = prior_distributions_func
+
     @property
-    def
+    def observation_container(self) -> dict[str, ObsConfiguration]:
         """
         The observations used in the fit as a dictionary of observations.
         """
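Taken together with the constructor change above, `prior_distributions` can now be either the nested dictionary (validated through `PriorDictModel`) or a callable that performs its own sampling and returns the same parameter tree. A hedged sketch of both call styles; `model` and `obs_list` are placeholders for a `SpectralModel` and a list of `ObsConfiguration`:

```python
import numpyro.distributions as dist

# Placeholder prior; module and parameter names are illustrative.
prior = {"powerlaw_1": {"alpha": dist.Uniform(0.0, 5.0), "norm": dist.LogUniform(1e-8, 1e-2)}}

# Dictionary form: validated with pydantic, expanded to one parameter set per observation.
fitter = NUTSFitter(model, prior, obs_list)

# Callable form: full control, must return the same nested parameter structure.
def custom_prior():
    return build_prior(prior, expand_shape=(len(obs_list),))

fitter = NUTSFitter(model, custom_prior, obs_list)
```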
@@ -170,33 +208,56 @@ class ModelFitter(ABC):
         else:
             raise ValueError(f"Invalid type for observations : {type(self._observations)}")
 
-
+    @property
+    def numpyro_model(self) -> Callable:
         """
         Build the numpyro model using the observed data, the prior distributions and the spectral model.
 
-        Parameters:
-            prior_distributions: a nested dictionary containing the prior distributions for the model parameters.
-
         Returns:
+        -------
             A model function that can be used with numpyro.
         """
 
         def model(observed=True):
-            prior_params =
+            prior_params = self.prior_distributions_func()
 
-
+            # Iterate over all the observations in our container and build a single numpyro model for each observation
+            for i, (key, observation) in enumerate(self.observation_container.items()):
+                # We expect that prior_params contains an array of parameters for each observation
+                # They can be identical or different for each observation
                 params = tree_map(lambda x: x[i], prior_params)
 
-                obs_model =
+                obs_model = build_numpyro_model_for_single_obs(
+                    observation, self.model, self.background_model, name=key, sparse=self.sparse
+                )
+
                 obs_model(params, observed=observed)
 
         return model
 
+    @property
+    def transformed_numpyro_model(self) -> Callable:
+        transform_dict = {}
+
+        relations = get_model_relations(self.numpyro_model)
+        distributions = {
+            parameter: getattr(numpyro.distributions, value, None)
+            for parameter, value in relations["sample_dist"].items()
+        }
+
+        for parameter, distribution in distributions.items():
+            if isinstance(distribution, TransformedDistribution):
+                transform_dict[parameter] = TransformReparam()
+
+        return numpyro.handlers.reparam(self.numpyro_model, config=transform_dict)
+
+
+class BayesianModelFitter(BayesianModel, ABC):
     @abstractmethod
-    def fit(self,
+    def fit(self, **kwargs) -> FitResult: ...
 
 
-class
+class NUTSFitter(BayesianModelFitter):
     """
     A class to fit a model to a given set of observation using a Bayesian approach. This class uses the NUTS sampler
     from numpyro to perform the inference on the model parameters.
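The new `transformed_numpyro_model` property automates what `NUTSFitter.fit` previously did by hand (see the removal further down): every sample site whose prior resolves to a `TransformedDistribution` is wrapped in a `TransformReparam`, so the sampler explores the unconstrained base space. The equivalent manual pattern on a toy model, with the site name `norm` purely illustrative:

```python
import numpyro
import numpyro.distributions as dist
from numpyro.infer.reparam import TransformReparam

def toy_model():
    # LogUniform is a TransformedDistribution, hence a reparametrization candidate.
    numpyro.sample("norm", dist.LogUniform(1e-8, 1e-2))

reparam_model = numpyro.handlers.reparam(toy_model, config={"norm": TransformReparam()})
```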
@@ -204,21 +265,20 @@ class BayesianFitter(ModelFitter):
 
     def fit(
         self,
-        prior_distributions: HaikuDict[Distribution],
         rng_key: int = 0,
-        num_chains: int =
+        num_chains: int = len(jax.devices()),
         num_warmup: int = 1000,
         num_samples: int = 1000,
         max_tree_depth: int = 10,
         target_accept_prob: float = 0.8,
         dense_mass: bool = False,
+        kernel_kwargs: dict = {},
         mcmc_kwargs: dict = {},
     ) -> FitResult:
         """
         Fit the model to the data using NUTS sampler from numpyro.
 
         Parameters:
-            prior_distributions: a nested dictionary containing the prior distributions for the model parameters.
             rng_key: the random key used to initialize the sampler.
             num_chains: the number of chains to run.
             num_warmup: the number of warmup steps.
@@ -232,11 +292,8 @@ class BayesianFitter(ModelFitter):
         A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
         """
 
-
-
-        for m, n, val in hk.data_structures.traverse(prior_distributions):
-            if isinstance(val, TransformedDistribution):
-                transform_dict[f"{m}_{n}"] = TransformReparam()
+        bayesian_model = self.transformed_numpyro_model
+        # bayesian_model = self.numpyro_model(prior_distributions)
 
         chain_kwargs = {
             "num_warmup": num_warmup,
@@ -244,50 +301,62 @@ class BayesianFitter(ModelFitter):
             "num_chains": num_chains,
         }
 
-
-
-
+        kernel = NUTS(
+            bayesian_model,
+            max_tree_depth=max_tree_depth,
+            target_accept_prob=target_accept_prob,
+            dense_mass=dense_mass,
+            **kernel_kwargs,
+        )
 
         mcmc = MCMC(kernel, **(chain_kwargs | mcmc_kwargs))
-
         keys = random.split(random.PRNGKey(rng_key), 3)
 
         mcmc.run(keys[0])
-
+
+        posterior_predictive = Predictive(bayesian_model, mcmc.get_samples())(
+            keys[1], observed=False
+        )
+
         prior = Predictive(bayesian_model, num_samples=num_samples)(keys[2], observed=False)
-        inference_data = az.from_numpyro(mcmc, prior=prior, posterior_predictive=posterior_predictive)
 
-        inference_data =
+        inference_data = az.from_numpyro(
+            mcmc, prior=prior, posterior_predictive=posterior_predictive
+        )
+
+        inference_data = filter_inference_data(
+            inference_data, self.observation_container, self.background_model
+        )
 
         return FitResult(
-            self
-            self._observation_container,
+            self,
             inference_data,
             self.model.params,
             background_model=self.background_model,
         )
 
 
-class MinimizationFitter(
+class MinimizationFitter(BayesianModelFitter):
     """
     A class to fit a model to a given set of observation using a minimization algorithm. This class uses the L-BFGS
     algorithm from jaxopt to perform the minimization on the model parameters. The uncertainties are computed using the
-    Hessian of the log-
+    Hessian of the log-log_likelihood, assuming that it is a multivariate Gaussian in the unbounded space defined by
     numpyro.
     """
 
     def fit(
         self,
-        prior_distributions: HaikuDict[Distribution],
         rng_key: int = 0,
-        num_iter_max: int =
+        num_iter_max: int = 100_000,
         num_samples: int = 1_000,
+        solver: Literal["bfgs", "levenberg_marquardt"] = "bfgs",
+        init_params=None,
+        refine_first_guess=True,
     ) -> FitResult:
         """
         Fit the model to the data using L-BFGS algorithm.
 
         Parameters:
-            prior_distributions: a nested dictionary containing the prior distributions for the model parameters.
             rng_key: the random key used to initialize the sampler.
             num_iter_max: the maximum number of iteration in the minimization algorithm.
             num_samples: the number of sample to draw from the best-fit covariance.
@@ -296,37 +365,205 @@ class MinimizationFitter(ModelFitter):
         A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
         """
 
-        bayesian_model = self.numpyro_model
+        bayesian_model = self.numpyro_model
+        keys = jax.random.split(PRNGKey(rng_key), 4)
 
-
-
-
-
-
-
+        if init_params is not None:
+            # We initialize the parameters by randomly sampling from the prior
+            local_keys = jax.random.split(keys[0], 2)
+
+            with numpyro.handlers.seed(rng_seed=local_keys[0]):
+                starting_value = self.prior_distributions_func()
+
+            # We update the starting value with the provided init_params
+            for m, n, val in hk.data_structures.traverse(init_params):
+                if f"{m}_{n}" in starting_value.keys():
+                    starting_value[f"{m}_{n}"] = val
+
+            init_params, _ = numpyro.infer.util.find_valid_initial_params(
+                local_keys[1], bayesian_model, init_strategy=init_to_value(values=starting_value)
+            )
+
+        else:
+            init_params, _ = numpyro.infer.util.find_valid_initial_params(keys[0], bayesian_model)
+
+        init_params = init_params[0]
 
-        # get negative log-density from the potential function
         @jax.jit
-        def
-
-
+        def nll(unconstrained_params, _):
+            constrained_params = constrain_fn(
+                bayesian_model, tuple(), dict(observed=True), unconstrained_params
+            )
 
-
-
-
+            log_likelihood = numpyro.infer.util.log_likelihood(
+                model=bayesian_model, posterior_samples=constrained_params
+            )
 
+            # We solve a least square problem, this function ensure that the total residual is indeed the nll
+            return jax.tree.map(lambda x: jnp.sqrt(-x), log_likelihood)
+
+        """
+        if refine_first_guess:
+            with catchtime("Refine_first"):
+                solution = optx.least_squares(
+                    nll,
+                    optx.BestSoFarMinimiser(optx.OptaxMinimiser(optax.adam(1e-4), 1e-6, 1e-6)),
+                    init_params,
+                    max_steps=1000,
+                    throw=False
+                )
+            init_params = solution.value
+        """
+
+        if solver == "bfgs":
+            solver = optx.BestSoFarMinimiser(optx.BFGS(1e-6, 1e-6))
+        elif solver == "levenberg_marquardt":
+            solver = optx.BestSoFarLeastSquares(optx.LevenbergMarquardt(1e-6, 1e-6))
+        else:
+            raise NotImplementedError(f"{solver} is not implemented")
+
+        with catchtime("Minimization"):
+            solution = optx.least_squares(
+                nll,
+                solver,
+                init_params,
+                max_steps=num_iter_max,
+            )
+
+        params = solution.value
         value_flat, unflatten_fun = ravel_pytree(params)
-        covariance = jnp.linalg.inv(jax.hessian(lambda p: nll_fn(unflatten_fun(p)))(value_flat))
 
-
-
-
+        with catchtime("Compute error"):
+            precision = jax.hessian(
+                lambda p: jnp.sum(ravel_pytree(nll(unflatten_fun(p), None))[0] ** 2)
+            )(value_flat)
 
-
-
+        cov = Covariance.from_precision(precision)
+
+        samples_flat = multivariate_normal.rvs(mean=value_flat, cov=cov, size=num_samples)
+
+        samples = jax.vmap(unflatten_fun)(samples_flat)
+        posterior_samples = jax.jit(
+            jax.vmap(lambda p: constrain_fn(bayesian_model, tuple(), dict(observed=True), p))
+        )(samples)
+
+        with catchtime("Posterior"):
+            posterior_predictive = Predictive(bayesian_model, posterior_samples)(
+                keys[2], observed=False
+            )
+            prior = Predictive(bayesian_model, num_samples=num_samples)(keys[3], observed=False)
+            log_likelihood = numpyro.infer.log_likelihood(bayesian_model, posterior_samples)
+
+        def sanitize_chain(chain):
+            """
+            reshape the samples so that it is arviz compliant with an extra starting dimension
+            """
+            return tree_map(lambda x: x[None, ...], chain)
+
+        # We export the observed values to the inference_data
+        seeded_model = numpyro.handlers.substitute(
+            numpyro.handlers.seed(bayesian_model, jax.random.PRNGKey(0)),
+            substitute_fn=numpyro.infer.init_to_sample,
+        )
+        trace = numpyro.handlers.trace(seeded_model).get_trace()
+        observations = {
+            name: site["value"]
+            for name, site in trace.items()
+            if site["type"] == "sample" and site["is_observed"]
+        }
+
+        with catchtime("InferenceData wrapping"):
+            inference_data = az.from_dict(
+                sanitize_chain(posterior_samples),
+                prior=sanitize_chain(prior),
+                posterior_predictive=sanitize_chain(posterior_predictive),
+                log_likelihood=sanitize_chain(log_likelihood),
+                observed_data=observations,
+            )
+
+        inference_data = filter_inference_data(
+            inference_data, self.observation_container, self.background_model
+        )
+
+        return FitResult(
+            self,
+            inference_data,
+            self.model.params,
+            background_model=self.background_model,
+        )
+
+
+class NestedSamplingFitter(BayesianModelFitter):
+    r"""
+    A class to fit a model to a given set of observation using the Nested Sampling algorithm. This class uses the
+    [`DefaultNestedSampler`][jaxns.DefaultNestedSampler] from [`jaxns`](https://jaxns.readthedocs.io/en/latest/) which
+    implements the [Phantom-Powered Nested Sampling](https://arxiv.org/abs/2312.11330) algorithm.
+    Add Citation to jaxns
+    """
+
+    def fit(
+        self,
+        rng_key: int = 0,
+        num_samples: int = 1000,
+        plot_diagnostics=False,
+        termination_kwargs: dict | None = None,
+        verbose=True,
+    ) -> FitResult:
+        """
+        Fit the model to the data using the Phantom-Powered nested sampling algorithm.
+
+        Parameters:
+            rng_key: the random key used to initialize the sampler.
+            num_samples: the number of samples to draw.
+
+        Returns:
+            A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
+        """
+
+        bayesian_model = self.transformed_numpyro_model
+        keys = random.split(random.PRNGKey(rng_key), 4)
+
+        ns = NestedSampler(
+            bayesian_model,
+            constructor_kwargs=dict(
+                num_parallel_workers=1,
+                verbose=verbose,
+                difficult_model=True,
+                max_samples=1e6,
+                parameter_estimation=True,
+                num_live_points=1_000,
+            ),
+            termination_kwargs=termination_kwargs if termination_kwargs else dict(),
+        )
+
+        ns.run(keys[0])
+
+        if plot_diagnostics:
+            ns.diagnostics()
+
+        posterior_samples = ns.get_samples(keys[1], num_samples=num_samples)
         log_likelihood = numpyro.infer.log_likelihood(bayesian_model, posterior_samples)
+        posterior_predictive = Predictive(bayesian_model, posterior_samples)(
+            keys[2], observed=False
+        )
+
+        prior = Predictive(bayesian_model, num_samples=num_samples)(keys[3], observed=False)
+
+        seeded_model = numpyro.handlers.substitute(
+            numpyro.handlers.seed(bayesian_model, jax.random.PRNGKey(0)),
+            substitute_fn=numpyro.infer.init_to_sample,
+        )
+        trace = numpyro.handlers.trace(seeded_model).get_trace()
+        observations = {
+            name: site["value"]
+            for name, site in trace.items()
+            if site["type"] == "sample" and site["is_observed"]
+        }
 
         def sanitize_chain(chain):
+            """
+            reshape the samples so that it is arviz compliant with an extra starting dimension
+            """
             return tree_map(lambda x: x[None, ...], chain)
 
         inference_data = az.from_dict(
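A note on the `nll` residual trick in `MinimizationFitter.fit`: least-squares solvers minimize the sum of squared residuals (up to a constant factor), so returning `sqrt(-log_likelihood)` per data point makes that sum equal the negative log-likelihood, as the inline comment claims. A quick self-contained check of the identity:

```python
import jax.numpy as jnp

log_p = jnp.log(jnp.array([0.2, 0.5, 0.9]))  # toy per-point log-likelihoods
residuals = jnp.sqrt(-log_p)                 # what nll() returns
assert jnp.allclose(jnp.sum(residuals**2), -jnp.sum(log_p))
```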
@@ -334,13 +571,15 @@ class MinimizationFitter(ModelFitter):
             prior=sanitize_chain(prior),
             posterior_predictive=sanitize_chain(posterior_predictive),
             log_likelihood=sanitize_chain(log_likelihood),
+            observed_data=observations,
         )
 
-        inference_data = filter_inference_data(
+        inference_data = filter_inference_data(
+            inference_data, self.observation_container, self.background_model
+        )
 
         return FitResult(
-            self
-            self._observation_container,
+            self,
             inference_data,
             self.model.params,
             background_model=self.background_model,
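In summary, 0.0.8 moves the prior from the `fit()` call into the constructor and splits fitting into three interchangeable front-ends. A hedged end-to-end sketch; `model`, `prior` and `obs_list` are the placeholders used above:

```python
# NUTS sampling, gradient-based minimization, or nested sampling over the same model.
result = NUTSFitter(model, prior, obs_list).fit(num_samples=1000)
result = MinimizationFitter(model, prior, obs_list).fit(solver="bfgs")
result = NestedSamplingFitter(model, prior, obs_list).fit(num_samples=1000)
```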
jaxspec/model/__init__.py
CHANGED
@@ -1 +0,0 @@
-