jaxspec 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxspec/__init__.py +1 -1
- jaxspec/analysis/compare.py +3 -3
- jaxspec/analysis/results.py +239 -110
- jaxspec/data/example_data/fakeit.pha +335 -1
- jaxspec/data/instrument.py +0 -2
- jaxspec/data/obsconf.py +101 -71
- jaxspec/data/observation.py +24 -6
- jaxspec/data/ogip.py +18 -0
- jaxspec/data/util.py +28 -34
- jaxspec/fit.py +166 -72
- jaxspec/model/_additive/__init__.py +0 -0
- jaxspec/model/_additive/apec.py +377 -0
- jaxspec/model/_additive/apec_loaders.py +90 -0
- jaxspec/model/abc.py +55 -7
- jaxspec/model/additive.py +6 -56
- jaxspec/tables/abundances.dat +31 -0
- jaxspec/tables/{apec.nc → new_apec.nc} +0 -0
- jaxspec/util/abundance.py +111 -0
- jaxspec/util/integrate.py +5 -4
- {jaxspec-0.0.3.dist-info → jaxspec-0.0.5.dist-info}/METADATA +8 -4
- {jaxspec-0.0.3.dist-info → jaxspec-0.0.5.dist-info}/RECORD +23 -17
- {jaxspec-0.0.3.dist-info → jaxspec-0.0.5.dist-info}/LICENSE.md +0 -0
- {jaxspec-0.0.3.dist-info → jaxspec-0.0.5.dist-info}/WHEEL +0 -0
jaxspec/data/util.py
CHANGED
|
@@ -6,9 +6,9 @@ import haiku as hk
|
|
|
6
6
|
from pathlib import Path
|
|
7
7
|
from numpy.typing import ArrayLike
|
|
8
8
|
from collections.abc import Mapping
|
|
9
|
-
from typing import TypeVar
|
|
9
|
+
from typing import TypeVar, Tuple
|
|
10
|
+
from astropy.io import fits
|
|
10
11
|
|
|
11
|
-
from .ogip import DataPHA, DataARF, DataRMF
|
|
12
12
|
from . import Observation, Instrument, ObsConfiguration
|
|
13
13
|
from ..model.abc import SpectralModel
|
|
14
14
|
from ..fit import CountForwardModel
|
|
@@ -104,6 +104,7 @@ def fakeit(
|
|
|
104
104
|
model: SpectralModel,
|
|
105
105
|
parameters: Mapping[K, V],
|
|
106
106
|
rng_key: int = 0,
|
|
107
|
+
sparsify_matrix: bool = False,
|
|
107
108
|
) -> ArrayLike | list[ArrayLike]:
|
|
108
109
|
"""
|
|
109
110
|
This function is a convenience function that allows to simulate spectra from a given model and a set of parameters.
|
|
@@ -116,13 +117,16 @@ def fakeit(
|
|
|
116
117
|
model: The model to use.
|
|
117
118
|
parameters: The parameters of the model.
|
|
118
119
|
rng_key: The random number generator seed.
|
|
120
|
+
sparsify_matrix: Whether to sparsify the matrix or not.
|
|
119
121
|
"""
|
|
120
122
|
|
|
121
123
|
instruments = [instrument] if isinstance(instrument, ObsConfiguration) else instrument
|
|
122
124
|
fakeits = []
|
|
123
125
|
|
|
124
126
|
for i, instrument in enumerate(instruments):
|
|
125
|
-
transformed_model = hk.without_apply_rng(
|
|
127
|
+
transformed_model = hk.without_apply_rng(
|
|
128
|
+
hk.transform(lambda par: CountForwardModel(model, instrument, sparse=sparsify_matrix)(par))
|
|
129
|
+
)
|
|
126
130
|
|
|
127
131
|
def obs_model(p):
|
|
128
132
|
return transformed_model.apply(None, p)
|
|
@@ -159,7 +163,8 @@ def fakeit_for_multiple_parameters(
|
|
|
159
163
|
model: SpectralModel,
|
|
160
164
|
parameters: Mapping[K, V],
|
|
161
165
|
rng_key: int = 0,
|
|
162
|
-
apply_stat=True,
|
|
166
|
+
apply_stat: bool = True,
|
|
167
|
+
sparsify_matrix: bool = False,
|
|
163
168
|
):
|
|
164
169
|
"""
|
|
165
170
|
This function is a convenience function that allows to simulate spectra multiple spectra from a given model and a
|
|
@@ -173,13 +178,16 @@ def fakeit_for_multiple_parameters(
|
|
|
173
178
|
parameters: The parameters of the model.
|
|
174
179
|
rng_key: The random number generator seed.
|
|
175
180
|
apply_stat: Whether to apply Poisson statistic on the folded spectra or not.
|
|
181
|
+
sparsify_matrix: Whether to sparsify the matrix or not.
|
|
176
182
|
"""
|
|
177
183
|
|
|
178
184
|
instruments = [instrument] if isinstance(instrument, ObsConfiguration) else instrument
|
|
179
185
|
fakeits = []
|
|
180
186
|
|
|
181
187
|
for i, obs in enumerate(instruments):
|
|
182
|
-
transformed_model = hk.without_apply_rng(
|
|
188
|
+
transformed_model = hk.without_apply_rng(
|
|
189
|
+
hk.transform(lambda par: CountForwardModel(model, obs, sparse=sparsify_matrix)(par))
|
|
190
|
+
)
|
|
183
191
|
|
|
184
192
|
@jax.jit
|
|
185
193
|
@jax.vmap
|
|
@@ -201,46 +209,32 @@ def fakeit_for_multiple_parameters(
|
|
|
201
209
|
return fakeits[0] if len(fakeits) == 1 else fakeits
|
|
202
210
|
|
|
203
211
|
|
|
204
|
-
def
|
|
212
|
+
def data_path_finder(pha_path: str) -> Tuple[str | None, str | None, str | None]:
|
|
205
213
|
"""
|
|
206
|
-
This function
|
|
207
|
-
from a given PHA file, using either the ARF/RMF/BKG filenames in the header or the
|
|
208
|
-
specified filenames overwritten by the user.
|
|
209
|
-
|
|
214
|
+
This function tries its best to find the ARF, RMF and BKG files associated with a given PHA file.
|
|
210
215
|
Parameters:
|
|
211
216
|
pha_path: The PHA file path.
|
|
217
|
+
|
|
218
|
+
Returns:
|
|
212
219
|
arf_path: The ARF file path.
|
|
213
220
|
rmf_path: The RMF file path.
|
|
214
221
|
bkg_path: The BKG file path.
|
|
215
222
|
"""
|
|
216
223
|
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
arf_path = find_file_or_compressed_in_dir(pha.ancrfile, directory)
|
|
223
|
-
|
|
224
|
-
if rmf_path is None:
|
|
225
|
-
if pha.respfile != "none" and pha.respfile != "":
|
|
226
|
-
rmf_path = find_file_or_compressed_in_dir(pha.respfile, directory)
|
|
227
|
-
|
|
228
|
-
if bkg_path is None:
|
|
229
|
-
if pha.backfile.lower() != "none" and pha.backfile != "":
|
|
230
|
-
bkg_path = find_file_or_compressed_in_dir(pha.backfile, directory)
|
|
224
|
+
def find_path(file_name: str, directory: str) -> str | None:
|
|
225
|
+
if file_name.lower() != "none" and file_name != "":
|
|
226
|
+
return find_file_or_compressed_in_dir(file_name, directory)
|
|
227
|
+
else:
|
|
228
|
+
return None
|
|
231
229
|
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
bkg = DataPHA.from_file(bkg_path) if bkg_path is not None else None
|
|
230
|
+
header = fits.getheader(pha_path, "SPECTRUM")
|
|
231
|
+
directory = str(Path(pha_path).parent)
|
|
235
232
|
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
"response_matrix_file": rmf_path,
|
|
240
|
-
"ancillary_response_file": arf_path,
|
|
241
|
-
}
|
|
233
|
+
arf_path = find_path(header.get("ANCRFILE", "none"), directory)
|
|
234
|
+
rmf_path = find_path(header.get("RESPFILE", "none"), directory)
|
|
235
|
+
bkg_path = find_path(header.get("BACKFILE", "none"), directory)
|
|
242
236
|
|
|
243
|
-
return
|
|
237
|
+
return arf_path, rmf_path, bkg_path
|
|
244
238
|
|
|
245
239
|
|
|
246
240
|
def find_file_or_compressed_in_dir(path: str | Path, directory: str | Path) -> str:
|
jaxspec/fit.py
CHANGED
|
@@ -4,11 +4,12 @@ import numpyro
|
|
|
4
4
|
import arviz as az
|
|
5
5
|
import jax
|
|
6
6
|
from typing import Callable, TypeVar
|
|
7
|
-
from abc import ABC
|
|
7
|
+
from abc import ABC, abstractmethod
|
|
8
8
|
from jax import random
|
|
9
9
|
from jax.tree_util import tree_map
|
|
10
|
+
from jax.flatten_util import ravel_pytree
|
|
10
11
|
from jax.experimental.sparse import BCSR
|
|
11
|
-
from .analysis.results import
|
|
12
|
+
from .analysis.results import FitResult
|
|
12
13
|
from .model.abc import SpectralModel
|
|
13
14
|
from .data import ObsConfiguration
|
|
14
15
|
from .model.background import BackgroundModel
|
|
@@ -17,13 +18,15 @@ from numpyro.distributions import Distribution, TransformedDistribution
|
|
|
17
18
|
from numpyro.distributions import Poisson
|
|
18
19
|
from jax.typing import ArrayLike
|
|
19
20
|
from numpyro.infer.reparam import TransformReparam
|
|
21
|
+
from numpyro.infer.util import initialize_model
|
|
22
|
+
from jax.random import PRNGKey
|
|
23
|
+
import jaxopt
|
|
20
24
|
|
|
21
25
|
|
|
22
26
|
T = TypeVar("T")
|
|
23
27
|
|
|
24
28
|
|
|
25
|
-
class HaikuDict(dict[str, dict[str, T]]):
|
|
26
|
-
...
|
|
29
|
+
class HaikuDict(dict[str, dict[str, T]]): ...
|
|
27
30
|
|
|
28
31
|
|
|
29
32
|
def build_prior(prior: HaikuDict[Distribution | ArrayLike], expand_shape: tuple = ()):
|
|
@@ -32,7 +35,8 @@ def build_prior(prior: HaikuDict[Distribution | ArrayLike], expand_shape: tuple
|
|
|
32
35
|
for i, (m, n, sample) in enumerate(hk.data_structures.traverse(prior)):
|
|
33
36
|
match sample:
|
|
34
37
|
case Distribution():
|
|
35
|
-
parameters[m][n] = numpyro.sample(f"{m}_{n}", sample
|
|
38
|
+
parameters[m][n] = jnp.ones(expand_shape) * numpyro.sample(f"{m}_{n}", sample)
|
|
39
|
+
# parameters[m][n] = numpyro.sample(f"{m}_{n}", sample.expand(expand_shape)) build a free parameter for each obs
|
|
36
40
|
case float() | ArrayLike():
|
|
37
41
|
parameters[m][n] = jnp.ones(expand_shape) * sample
|
|
38
42
|
case _:
|
|
@@ -54,7 +58,7 @@ def build_numpyro_model(
|
|
|
54
58
|
|
|
55
59
|
if (getattr(obs, "folded_background", None) is not None) and (background_model is not None):
|
|
56
60
|
bkg_countrate = background_model.numpyro_model(
|
|
57
|
-
obs.out_energies, obs.folded_background.data, name=
|
|
61
|
+
obs.out_energies, obs.folded_background.data, name="bkg_" + name, observed=observed
|
|
58
62
|
)
|
|
59
63
|
elif (getattr(obs, "folded_background", None) is None) and (background_model is not None):
|
|
60
64
|
raise ValueError("Trying to fit a background model but no background is linked to this observation")
|
|
@@ -66,9 +70,9 @@ def build_numpyro_model(
|
|
|
66
70
|
countrate = obs_model(prior_params)
|
|
67
71
|
|
|
68
72
|
# This is the case where we fit a model to a TOTAL spectrum as defined in OGIP standard
|
|
69
|
-
with numpyro.plate(
|
|
73
|
+
with numpyro.plate("obs_plate_" + name, len(obs.folded_counts)):
|
|
70
74
|
numpyro.sample(
|
|
71
|
-
|
|
75
|
+
"obs_" + name,
|
|
72
76
|
Poisson(countrate + bkg_countrate / obs.folded_backratio.data),
|
|
73
77
|
obs=obs.folded_counts.data if observed else None,
|
|
74
78
|
)
|
|
@@ -76,6 +80,25 @@ def build_numpyro_model(
|
|
|
76
80
|
return numpro_model
|
|
77
81
|
|
|
78
82
|
|
|
83
|
+
def filter_inference_data(inference_data, observation_container, background_model=None) -> az.InferenceData:
|
|
84
|
+
predictive_parameters = []
|
|
85
|
+
|
|
86
|
+
for key, value in observation_container.items():
|
|
87
|
+
if background_model is not None:
|
|
88
|
+
predictive_parameters.append(f"obs_{key}")
|
|
89
|
+
predictive_parameters.append(f"bkg_{key}")
|
|
90
|
+
else:
|
|
91
|
+
predictive_parameters.append(f"obs_{key}")
|
|
92
|
+
|
|
93
|
+
inference_data.posterior_predictive = inference_data.posterior_predictive[predictive_parameters]
|
|
94
|
+
|
|
95
|
+
parameters = [x for x in inference_data.posterior.keys() if not x.endswith("_base")]
|
|
96
|
+
inference_data.posterior = inference_data.posterior[parameters]
|
|
97
|
+
inference_data.prior = inference_data.prior[parameters]
|
|
98
|
+
|
|
99
|
+
return inference_data
|
|
100
|
+
|
|
101
|
+
|
|
79
102
|
class CountForwardModel(hk.Module):
|
|
80
103
|
"""
|
|
81
104
|
A haiku module which allows to build the function that simulates the measured counts
|
|
@@ -97,27 +120,87 @@ class CountForwardModel(hk.Module):
|
|
|
97
120
|
Compute the count functions for a given observation.
|
|
98
121
|
"""
|
|
99
122
|
|
|
100
|
-
expected_counts = self.transfer_matrix @ self.model(parameters, *self.energies)
|
|
123
|
+
expected_counts = self.transfer_matrix @ self.model.photon_flux(parameters, *self.energies)
|
|
101
124
|
|
|
102
125
|
return jnp.clip(expected_counts, a_min=1e-6)
|
|
103
126
|
|
|
104
127
|
|
|
105
|
-
class
|
|
128
|
+
class ModelFitter(ABC):
|
|
106
129
|
"""
|
|
107
130
|
Abstract class to fit a model to a given set of observation.
|
|
108
131
|
"""
|
|
109
132
|
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
133
|
+
def __init__(
|
|
134
|
+
self,
|
|
135
|
+
model: SpectralModel,
|
|
136
|
+
observations: ObsConfiguration | list[ObsConfiguration] | dict[str, ObsConfiguration],
|
|
137
|
+
background_model: BackgroundModel = None,
|
|
138
|
+
sparsify_matrix: bool = False,
|
|
139
|
+
):
|
|
140
|
+
"""
|
|
141
|
+
Initialize the fitter.
|
|
117
142
|
|
|
118
|
-
|
|
143
|
+
Parameters:
|
|
144
|
+
model: the spectral model to fit.
|
|
145
|
+
observations: the observations to fit the model to.
|
|
146
|
+
background_model: the background model to fit.
|
|
147
|
+
sparsify_matrix: whether to sparsify the transfer matrix.
|
|
148
|
+
"""
|
|
119
149
|
self.model = model
|
|
150
|
+
self._observations = observations
|
|
151
|
+
self.background_model = background_model
|
|
120
152
|
self.pars = tree_map(lambda x: jnp.float64(x), self.model.params)
|
|
153
|
+
self.sparse = sparsify_matrix
|
|
154
|
+
|
|
155
|
+
@property
|
|
156
|
+
def _observation_container(self) -> dict[str, ObsConfiguration]:
|
|
157
|
+
"""
|
|
158
|
+
The observations used in the fit as a dictionary of observations.
|
|
159
|
+
"""
|
|
160
|
+
|
|
161
|
+
if isinstance(self._observations, dict):
|
|
162
|
+
return self._observations
|
|
163
|
+
|
|
164
|
+
elif isinstance(self._observations, list):
|
|
165
|
+
return {f"data_{i}": obs for i, obs in enumerate(self._observations)}
|
|
166
|
+
|
|
167
|
+
elif isinstance(self._observations, ObsConfiguration):
|
|
168
|
+
return {"data": self._observations}
|
|
169
|
+
|
|
170
|
+
else:
|
|
171
|
+
raise ValueError(f"Invalid type for observations : {type(self._observations)}")
|
|
172
|
+
|
|
173
|
+
def numpyro_model(self, prior_distributions: HaikuDict[Distribution]) -> Callable:
|
|
174
|
+
"""
|
|
175
|
+
Build the numpyro model using the observed data, the prior distributions and the spectral model.
|
|
176
|
+
|
|
177
|
+
Parameters:
|
|
178
|
+
prior_distributions: a nested dictionary containing the prior distributions for the model parameters.
|
|
179
|
+
|
|
180
|
+
Returns:
|
|
181
|
+
A model function that can be used with numpyro.
|
|
182
|
+
"""
|
|
183
|
+
|
|
184
|
+
def model(observed=True):
|
|
185
|
+
prior_params = build_prior(prior_distributions, expand_shape=(len(self._observation_container),))
|
|
186
|
+
|
|
187
|
+
for i, (key, observation) in enumerate(self._observation_container.items()):
|
|
188
|
+
params = tree_map(lambda x: x[i], prior_params)
|
|
189
|
+
|
|
190
|
+
obs_model = build_numpyro_model(observation, self.model, self.background_model, name=key, sparse=self.sparse)
|
|
191
|
+
obs_model(params, observed=observed)
|
|
192
|
+
|
|
193
|
+
return model
|
|
194
|
+
|
|
195
|
+
@abstractmethod
|
|
196
|
+
def fit(self, prior_distributions: HaikuDict[Distribution], **kwargs) -> FitResult: ...
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
class BayesianFitter(ModelFitter):
|
|
200
|
+
"""
|
|
201
|
+
A class to fit a model to a given set of observation using a Bayesian approach. This class uses the NUTS sampler
|
|
202
|
+
from numpyro to perform the inference on the model parameters.
|
|
203
|
+
"""
|
|
121
204
|
|
|
122
205
|
def fit(
|
|
123
206
|
self,
|
|
@@ -128,11 +211,11 @@ class BayesianModelAbstract(ABC):
|
|
|
128
211
|
num_samples: int = 1000,
|
|
129
212
|
max_tree_depth: int = 10,
|
|
130
213
|
target_accept_prob: float = 0.8,
|
|
131
|
-
dense_mass=False,
|
|
214
|
+
dense_mass: bool = False,
|
|
132
215
|
mcmc_kwargs: dict = {},
|
|
133
|
-
) ->
|
|
216
|
+
) -> FitResult:
|
|
134
217
|
"""
|
|
135
|
-
Fit the model to the data using NUTS sampler from numpyro.
|
|
218
|
+
Fit the model to the data using NUTS sampler from numpyro.
|
|
136
219
|
|
|
137
220
|
Parameters:
|
|
138
221
|
prior_distributions: a nested dictionary containing the prior distributions for the model parameters.
|
|
@@ -146,7 +229,7 @@ class BayesianModelAbstract(ABC):
|
|
|
146
229
|
mcmc_kwargs: additional arguments to pass to the MCMC sampler. See [`MCMC`][numpyro.infer.mcmc.MCMC] for more details.
|
|
147
230
|
|
|
148
231
|
Returns:
|
|
149
|
-
A [`
|
|
232
|
+
A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
|
|
150
233
|
"""
|
|
151
234
|
|
|
152
235
|
transform_dict = {}
|
|
@@ -174,80 +257,91 @@ class BayesianModelAbstract(ABC):
|
|
|
174
257
|
prior = Predictive(bayesian_model, num_samples=num_samples)(keys[2], observed=False)
|
|
175
258
|
inference_data = az.from_numpyro(mcmc, prior=prior, posterior_predictive=posterior_predictive)
|
|
176
259
|
|
|
177
|
-
|
|
178
|
-
inference_data.posterior_predictive = inference_data.posterior_predictive[predictive_parameters]
|
|
179
|
-
|
|
180
|
-
parameters = [x for x in inference_data.posterior.keys() if not x.endswith("_base")]
|
|
181
|
-
inference_data.posterior = inference_data.posterior[parameters]
|
|
182
|
-
inference_data.prior = inference_data.prior[parameters]
|
|
260
|
+
inference_data = filter_inference_data(inference_data, self._observation_container, self.background_model)
|
|
183
261
|
|
|
184
|
-
return
|
|
262
|
+
return FitResult(
|
|
185
263
|
self.model,
|
|
186
|
-
self.
|
|
264
|
+
self._observation_container,
|
|
187
265
|
inference_data,
|
|
188
|
-
mcmc.get_samples(),
|
|
189
266
|
self.model.params,
|
|
190
267
|
background_model=self.background_model,
|
|
191
268
|
)
|
|
192
269
|
|
|
193
270
|
|
|
194
|
-
class
|
|
271
|
+
class MinimizationFitter(ModelFitter):
|
|
195
272
|
"""
|
|
196
|
-
|
|
273
|
+
A class to fit a model to a given set of observation using a minimization algorithm. This class uses the L-BFGS
|
|
274
|
+
algorithm from jaxopt to perform the minimization on the model parameters. The uncertainties are computed using the
|
|
275
|
+
Hessian of the log-likelihood, assuming that it is a multivariate Gaussian in the unbounded space defined by
|
|
276
|
+
numpyro.
|
|
197
277
|
"""
|
|
198
278
|
|
|
199
|
-
def
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
def numpyro_model(self, prior_distributions: HaikuDict[Distribution]) -> Callable:
|
|
279
|
+
def fit(
|
|
280
|
+
self,
|
|
281
|
+
prior_distributions: HaikuDict[Distribution],
|
|
282
|
+
rng_key: int = 0,
|
|
283
|
+
num_iter_max: int = 10_000,
|
|
284
|
+
num_samples: int = 1_000,
|
|
285
|
+
) -> FitResult:
|
|
207
286
|
"""
|
|
208
|
-
|
|
209
|
-
to fit the model using numpyro's various samplers.
|
|
287
|
+
Fit the model to the data using L-BFGS algorithm.
|
|
210
288
|
|
|
211
289
|
Parameters:
|
|
212
290
|
prior_distributions: a nested dictionary containing the prior distributions for the model parameters.
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
prior_params = build_prior(prior_distributions)
|
|
217
|
-
obs_model = build_numpyro_model(self.observation, self.model, self.background_model, sparse=self.sparse)
|
|
218
|
-
obs_model(prior_params, observed=observed)
|
|
219
|
-
|
|
220
|
-
return model
|
|
291
|
+
rng_key: the random key used to initialize the sampler.
|
|
292
|
+
num_iter_max: the maximum number of iteration in the minimization algorithm.
|
|
293
|
+
num_samples: the number of sample to draw from the best-fit covariance.
|
|
221
294
|
|
|
295
|
+
Returns:
|
|
296
|
+
A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
|
|
297
|
+
"""
|
|
222
298
|
|
|
223
|
-
|
|
224
|
-
class MultipleObservationMCMC(BayesianModelAbstract):
|
|
299
|
+
bayesian_model = self.numpyro_model(prior_distributions)
|
|
225
300
|
|
|
226
|
-
|
|
301
|
+
param_info, potential_fn, postprocess_fn, *_ = initialize_model(
|
|
302
|
+
PRNGKey(0),
|
|
303
|
+
bayesian_model,
|
|
304
|
+
model_args=tuple(),
|
|
305
|
+
dynamic_args=True, # <- this is important!
|
|
306
|
+
)
|
|
227
307
|
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
308
|
+
# get negative log-density from the potential function
|
|
309
|
+
@jax.jit
|
|
310
|
+
def nll_fn(position):
|
|
311
|
+
func = potential_fn()
|
|
312
|
+
return func(position)
|
|
232
313
|
|
|
233
|
-
|
|
314
|
+
solver = jaxopt.LBFGS(fun=nll_fn, maxiter=10_000)
|
|
315
|
+
params, state = solver.run(param_info.z)
|
|
316
|
+
keys = random.split(random.PRNGKey(rng_key), 3)
|
|
234
317
|
|
|
235
|
-
|
|
318
|
+
value_flat, unflatten_fun = ravel_pytree(params)
|
|
319
|
+
covariance = jnp.linalg.inv(jax.hessian(lambda p: nll_fn(unflatten_fun(p)))(value_flat))
|
|
236
320
|
|
|
237
|
-
|
|
321
|
+
samples_flat = jax.random.multivariate_normal(keys[0], value_flat, covariance, shape=(num_samples,))
|
|
322
|
+
samples = jax.vmap(unflatten_fun)(samples_flat.block_until_ready())
|
|
323
|
+
posterior_samples = postprocess_fn()(samples)
|
|
238
324
|
|
|
239
|
-
|
|
325
|
+
posterior_predictive = Predictive(bayesian_model, posterior_samples)(keys[1], observed=False)
|
|
326
|
+
prior = Predictive(bayesian_model, num_samples=num_samples)(keys[2], observed=False)
|
|
327
|
+
log_likelihood = numpyro.infer.log_likelihood(bayesian_model, posterior_samples)
|
|
240
328
|
|
|
241
|
-
|
|
329
|
+
def sanitize_chain(chain):
|
|
330
|
+
return tree_map(lambda x: x[None, ...], chain)
|
|
242
331
|
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
332
|
+
inference_data = az.from_dict(
|
|
333
|
+
sanitize_chain(posterior_samples),
|
|
334
|
+
prior=sanitize_chain(prior),
|
|
335
|
+
posterior_predictive=sanitize_chain(posterior_predictive),
|
|
336
|
+
log_likelihood=sanitize_chain(log_likelihood),
|
|
337
|
+
)
|
|
249
338
|
|
|
250
|
-
|
|
339
|
+
inference_data = filter_inference_data(inference_data, self._observation_container, self.background_model)
|
|
251
340
|
|
|
252
|
-
return
|
|
253
|
-
|
|
341
|
+
return FitResult(
|
|
342
|
+
self.model,
|
|
343
|
+
self._observation_container,
|
|
344
|
+
inference_data,
|
|
345
|
+
self.model.params,
|
|
346
|
+
background_model=self.background_model,
|
|
347
|
+
)
|
|
File without changes
|