jaxspec 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in that registry.
jaxspec/fit.py CHANGED
@@ -4,11 +4,12 @@ import numpyro
4
4
  import arviz as az
5
5
  import jax
6
6
  from typing import Callable, TypeVar
7
- from abc import ABC
7
+ from abc import ABC, abstractmethod
8
8
  from jax import random
9
9
  from jax.tree_util import tree_map
10
+ from jax.flatten_util import ravel_pytree
10
11
  from jax.experimental.sparse import BCSR
11
- from .analysis.results import ChainResult
12
+ from .analysis.results import FitResult
12
13
  from .model.abc import SpectralModel
13
14
  from .data import ObsConfiguration
14
15
  from .model.background import BackgroundModel
@@ -17,13 +18,15 @@ from numpyro.distributions import Distribution, TransformedDistribution
17
18
  from numpyro.distributions import Poisson
18
19
  from jax.typing import ArrayLike
19
20
  from numpyro.infer.reparam import TransformReparam
21
+ from numpyro.infer.util import initialize_model
22
+ from jax.random import PRNGKey
23
+ import jaxopt
20
24
 
21
25
 
22
26
  T = TypeVar("T")
23
27
 
24
28
 
25
- class HaikuDict(dict[str, dict[str, T]]):
26
- ...
29
+ class HaikuDict(dict[str, dict[str, T]]): ...
27
30
 
28
31
 
29
32
  def build_prior(prior: HaikuDict[Distribution | ArrayLike], expand_shape: tuple = ()):
@@ -32,7 +35,8 @@ def build_prior(prior: HaikuDict[Distribution | ArrayLike], expand_shape: tuple
32
35
  for i, (m, n, sample) in enumerate(hk.data_structures.traverse(prior)):
33
36
  match sample:
34
37
  case Distribution():
35
- parameters[m][n] = numpyro.sample(f"{m}_{n}", sample.expand(expand_shape))
38
+ parameters[m][n] = jnp.ones(expand_shape) * numpyro.sample(f"{m}_{n}", sample)
39
+ # parameters[m][n] = numpyro.sample(f"{m}_{n}", sample.expand(expand_shape)) build a free parameter for each obs
36
40
  case float() | ArrayLike():
37
41
  parameters[m][n] = jnp.ones(expand_shape) * sample
38
42
  case _:
@@ -54,7 +58,7 @@ def build_numpyro_model(
54
58
 
55
59
  if (getattr(obs, "folded_background", None) is not None) and (background_model is not None):
56
60
  bkg_countrate = background_model.numpyro_model(
57
- obs.out_energies, obs.folded_background.data, name=name + "bkg", observed=observed
61
+ obs.out_energies, obs.folded_background.data, name="bkg_" + name, observed=observed
58
62
  )
59
63
  elif (getattr(obs, "folded_background", None) is None) and (background_model is not None):
60
64
  raise ValueError("Trying to fit a background model but no background is linked to this observation")
@@ -66,9 +70,9 @@ def build_numpyro_model(
66
70
  countrate = obs_model(prior_params)
67
71
 
68
72
  # This is the case where we fit a model to a TOTAL spectrum as defined in OGIP standard
69
- with numpyro.plate(name + "obs_plate", len(obs.folded_counts)):
73
+ with numpyro.plate("obs_plate_" + name, len(obs.folded_counts)):
70
74
  numpyro.sample(
71
- name + "obs",
75
+ "obs_" + name,
72
76
  Poisson(countrate + bkg_countrate / obs.folded_backratio.data),
73
77
  obs=obs.folded_counts.data if observed else None,
74
78
  )
@@ -76,6 +80,25 @@ def build_numpyro_model(
76
80
  return numpro_model
77
81
 
78
82
 
83
+ def filter_inference_data(inference_data, observation_container, background_model=None) -> az.InferenceData:
84
+ predictive_parameters = []
85
+
86
+ for key, value in observation_container.items():
87
+ if background_model is not None:
88
+ predictive_parameters.append(f"obs_{key}")
89
+ predictive_parameters.append(f"bkg_{key}")
90
+ else:
91
+ predictive_parameters.append(f"obs_{key}")
92
+
93
+ inference_data.posterior_predictive = inference_data.posterior_predictive[predictive_parameters]
94
+
95
+ parameters = [x for x in inference_data.posterior.keys() if not x.endswith("_base")]
96
+ inference_data.posterior = inference_data.posterior[parameters]
97
+ inference_data.prior = inference_data.prior[parameters]
98
+
99
+ return inference_data
100
+
101
+
79
102
  class CountForwardModel(hk.Module):
80
103
  """
81
104
  A haiku module which allows to build the function that simulates the measured counts
@@ -97,27 +120,87 @@ class CountForwardModel(hk.Module):
97
120
  Compute the count functions for a given observation.
98
121
  """
99
122
 
100
- expected_counts = self.transfer_matrix @ self.model(parameters, *self.energies)
123
+ expected_counts = self.transfer_matrix @ self.model.photon_flux(parameters, *self.energies)
101
124
 
102
125
  return jnp.clip(expected_counts, a_min=1e-6)
103
126
 
104
127
 
105
- class BayesianModelAbstract(ABC):
128
+ class ModelFitter(ABC):
106
129
  """
107
130
  Abstract class to fit a model to a given set of observation.
108
131
  """
109
132
 
110
- model: SpectralModel
111
- """The model to fit to the data."""
112
- numpyro_model: Callable
113
- """The numpyro model defining the likelihood."""
114
- background_model: BackgroundModel
115
- """The background model."""
116
- pars: dict
133
+ def __init__(
134
+ self,
135
+ model: SpectralModel,
136
+ observations: ObsConfiguration | list[ObsConfiguration] | dict[str, ObsConfiguration],
137
+ background_model: BackgroundModel = None,
138
+ sparsify_matrix: bool = False,
139
+ ):
140
+ """
141
+ Initialize the fitter.
117
142
 
118
- def __init__(self, model: SpectralModel):
143
+ Parameters:
144
+ model: the spectral model to fit.
145
+ observations: the observations to fit the model to.
146
+ background_model: the background model to fit.
147
+ sparsify_matrix: whether to sparsify the transfer matrix.
148
+ """
119
149
  self.model = model
150
+ self._observations = observations
151
+ self.background_model = background_model
120
152
  self.pars = tree_map(lambda x: jnp.float64(x), self.model.params)
153
+ self.sparse = sparsify_matrix
154
+
155
+ @property
156
+ def _observation_container(self) -> dict[str, ObsConfiguration]:
157
+ """
158
+ The observations used in the fit as a dictionary of observations.
159
+ """
160
+
161
+ if isinstance(self._observations, dict):
162
+ return self._observations
163
+
164
+ elif isinstance(self._observations, list):
165
+ return {f"data_{i}": obs for i, obs in enumerate(self._observations)}
166
+
167
+ elif isinstance(self._observations, ObsConfiguration):
168
+ return {"data": self._observations}
169
+
170
+ else:
171
+ raise ValueError(f"Invalid type for observations : {type(self._observations)}")
172
+
173
+ def numpyro_model(self, prior_distributions: HaikuDict[Distribution]) -> Callable:
174
+ """
175
+ Build the numpyro model using the observed data, the prior distributions and the spectral model.
176
+
177
+ Parameters:
178
+ prior_distributions: a nested dictionary containing the prior distributions for the model parameters.
179
+
180
+ Returns:
181
+ A model function that can be used with numpyro.
182
+ """
183
+
184
+ def model(observed=True):
185
+ prior_params = build_prior(prior_distributions, expand_shape=(len(self._observation_container),))
186
+
187
+ for i, (key, observation) in enumerate(self._observation_container.items()):
188
+ params = tree_map(lambda x: x[i], prior_params)
189
+
190
+ obs_model = build_numpyro_model(observation, self.model, self.background_model, name=key, sparse=self.sparse)
191
+ obs_model(params, observed=observed)
192
+
193
+ return model
194
+
195
+ @abstractmethod
196
+ def fit(self, prior_distributions: HaikuDict[Distribution], **kwargs) -> FitResult: ...
197
+
198
+
199
+ class BayesianFitter(ModelFitter):
200
+ """
201
+ A class to fit a model to a given set of observation using a Bayesian approach. This class uses the NUTS sampler
202
+ from numpyro to perform the inference on the model parameters.
203
+ """
121
204
 
122
205
  def fit(
123
206
  self,
@@ -128,11 +211,11 @@ class BayesianModelAbstract(ABC):
128
211
  num_samples: int = 1000,
129
212
  max_tree_depth: int = 10,
130
213
  target_accept_prob: float = 0.8,
131
- dense_mass=False,
214
+ dense_mass: bool = False,
132
215
  mcmc_kwargs: dict = {},
133
- ) -> ChainResult:
216
+ ) -> FitResult:
134
217
  """
135
- Fit the model to the data using NUTS sampler from numpyro. This is the default sampler in jaxspec.
218
+ Fit the model to the data using NUTS sampler from numpyro.
136
219
 
137
220
  Parameters:
138
221
  prior_distributions: a nested dictionary containing the prior distributions for the model parameters.
@@ -146,7 +229,7 @@ class BayesianModelAbstract(ABC):
146
229
  mcmc_kwargs: additional arguments to pass to the MCMC sampler. See [`MCMC`][numpyro.infer.mcmc.MCMC] for more details.
147
230
 
148
231
  Returns:
149
- A [`ChainResult`][jaxspec.analysis.results.ChainResult] instance containing the results of the fit.
232
+ A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
150
233
  """
151
234
 
152
235
  transform_dict = {}
@@ -174,80 +257,91 @@ class BayesianModelAbstract(ABC):
174
257
  prior = Predictive(bayesian_model, num_samples=num_samples)(keys[2], observed=False)
175
258
  inference_data = az.from_numpyro(mcmc, prior=prior, posterior_predictive=posterior_predictive)
176
259
 
177
- predictive_parameters = ["obs", "bkg"] if self.background_model is not None else ["obs"]
178
- inference_data.posterior_predictive = inference_data.posterior_predictive[predictive_parameters]
179
-
180
- parameters = [x for x in inference_data.posterior.keys() if not x.endswith("_base")]
181
- inference_data.posterior = inference_data.posterior[parameters]
182
- inference_data.prior = inference_data.prior[parameters]
260
+ inference_data = filter_inference_data(inference_data, self._observation_container, self.background_model)
183
261
 
184
- return ChainResult(
262
+ return FitResult(
185
263
  self.model,
186
- self.observation,
264
+ self._observation_container,
187
265
  inference_data,
188
- mcmc.get_samples(),
189
266
  self.model.params,
190
267
  background_model=self.background_model,
191
268
  )
192
269
 
193
270
 
194
- class BayesianModel(BayesianModelAbstract):
271
+ class MinimizationFitter(ModelFitter):
195
272
  """
196
- Class to fit a model to a given observation using a Bayesian approach.
273
+ A class to fit a model to a given set of observation using a minimization algorithm. This class uses the L-BFGS
274
+ algorithm from jaxopt to perform the minimization on the model parameters. The uncertainties are computed using the
275
+ Hessian of the log-likelihood, assuming that it is a multivariate Gaussian in the unbounded space defined by
276
+ numpyro.
197
277
  """
198
278
 
199
- def __init__(self, model, observation, background_model: BackgroundModel = None, sparsify_matrix: bool = False):
200
- super().__init__(model)
201
- self.observation = observation
202
- self.pars = tree_map(lambda x: jnp.float64(x), self.model.params)
203
- self.sparse = sparsify_matrix
204
- self.background_model = background_model
205
-
206
- def numpyro_model(self, prior_distributions: HaikuDict[Distribution]) -> Callable:
279
+ def fit(
280
+ self,
281
+ prior_distributions: HaikuDict[Distribution],
282
+ rng_key: int = 0,
283
+ num_iter_max: int = 10_000,
284
+ num_samples: int = 1_000,
285
+ ) -> FitResult:
207
286
  """
208
- Build the numpyro model for the Bayesian fit. It returns a callable which can be used
209
- to fit the model using numpyro's various samplers.
287
+ Fit the model to the data using L-BFGS algorithm.
210
288
 
211
289
  Parameters:
212
290
  prior_distributions: a nested dictionary containing the prior distributions for the model parameters.
213
- """
214
-
215
- def model(observed=True):
216
- prior_params = build_prior(prior_distributions)
217
- obs_model = build_numpyro_model(self.observation, self.model, self.background_model, sparse=self.sparse)
218
- obs_model(prior_params, observed=observed)
219
-
220
- return model
291
+ rng_key: the random key used to initialize the sampler.
292
+ num_iter_max: the maximum number of iteration in the minimization algorithm.
293
+ num_samples: the number of sample to draw from the best-fit covariance.
221
294
 
295
+ Returns:
296
+ A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
297
+ """
222
298
 
223
- """
224
- class MultipleObservationMCMC(BayesianModelAbstract):
299
+ bayesian_model = self.numpyro_model(prior_distributions)
225
300
 
226
- def __init__(self, model, observations, background_model: BackgroundModel = None):
301
+ param_info, potential_fn, postprocess_fn, *_ = initialize_model(
302
+ PRNGKey(0),
303
+ bayesian_model,
304
+ model_args=tuple(),
305
+ dynamic_args=True, # <- this is important!
306
+ )
227
307
 
228
- super().__init__(model)
229
- self.observations = observations
230
- self.pars = tree_map(lambda x: jnp.float64(x), self.model.params)
231
- self.background_model = background_model
308
+ # get negative log-density from the potential function
309
+ @jax.jit
310
+ def nll_fn(position):
311
+ func = potential_fn()
312
+ return func(position)
232
313
 
233
- def numpyro_model(self, prior_distributions: HaikuDict[Distribution]) -> Callable:
314
+ solver = jaxopt.LBFGS(fun=nll_fn, maxiter=10_000)
315
+ params, state = solver.run(param_info.z)
316
+ keys = random.split(random.PRNGKey(rng_key), 3)
234
317
 
235
- def model(observed=True):
318
+ value_flat, unflatten_fun = ravel_pytree(params)
319
+ covariance = jnp.linalg.inv(jax.hessian(lambda p: nll_fn(unflatten_fun(p)))(value_flat))
236
320
 
237
- prior_params = build_prior(prior_distributions, expand_shape=(len(self.observations),))
321
+ samples_flat = jax.random.multivariate_normal(keys[0], value_flat, covariance, shape=(num_samples,))
322
+ samples = jax.vmap(unflatten_fun)(samples_flat.block_until_ready())
323
+ posterior_samples = postprocess_fn()(samples)
238
324
 
239
- for i, (key, observation) in enumerate(self.observations.items()):
325
+ posterior_predictive = Predictive(bayesian_model, posterior_samples)(keys[1], observed=False)
326
+ prior = Predictive(bayesian_model, num_samples=num_samples)(keys[2], observed=False)
327
+ log_likelihood = numpyro.infer.log_likelihood(bayesian_model, posterior_samples)
240
328
 
241
- params = tree_map(lambda x: x[i], prior_params)
329
+ def sanitize_chain(chain):
330
+ return tree_map(lambda x: x[None, ...], chain)
242
331
 
243
- obs_model = build_numpyro_model(
244
- observation,
245
- self.model,
246
- self.background_model,
247
- name=key + '_'
248
- )
332
+ inference_data = az.from_dict(
333
+ sanitize_chain(posterior_samples),
334
+ prior=sanitize_chain(prior),
335
+ posterior_predictive=sanitize_chain(posterior_predictive),
336
+ log_likelihood=sanitize_chain(log_likelihood),
337
+ )
249
338
 
250
- obs_model(params, observed=observed)
339
+ inference_data = filter_inference_data(inference_data, self._observation_container, self.background_model)
251
340
 
252
- return model
253
- """
341
+ return FitResult(
342
+ self.model,
343
+ self._observation_container,
344
+ inference_data,
345
+ self.model.params,
346
+ background_model=self.background_model,
347
+ )
File without changes