jaxspec 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
jaxspec/fit.py CHANGED
@@ -1,67 +1,84 @@
+ from abc import ABC, abstractmethod
+ from collections.abc import Callable
+ from typing import Literal
+
+ import arviz as az
  import haiku as hk
+ import jax
  import jax.numpy as jnp
  import numpyro
- import arviz as az
- import jax
- from typing import Callable, TypeVar
- from abc import ABC, abstractmethod
+ import optimistix as optx
+
  from jax import random
- from jax.tree_util import tree_map
+ from jax.experimental.sparse import BCOO
  from jax.flatten_util import ravel_pytree
- from jax.experimental.sparse import BCSR
- from .analysis.results import FitResult
- from .model.abc import SpectralModel
- from .data import ObsConfiguration
- from .model.background import BackgroundModel
- from numpyro.infer import MCMC, NUTS, Predictive
- from numpyro.distributions import Distribution, TransformedDistribution
- from numpyro.distributions import Poisson
+ from jax.random import PRNGKey
+ from jax.tree_util import tree_map
  from jax.typing import ArrayLike
+ from numpyro.contrib.nested_sampling import NestedSampler
+ from numpyro.distributions import Distribution, Poisson, TransformedDistribution
+ from numpyro.infer import MCMC, NUTS, Predictive
+ from numpyro.infer.initialization import init_to_value
+ from numpyro.infer.inspect import get_model_relations
  from numpyro.infer.reparam import TransformReparam
- from numpyro.infer.util import initialize_model
- from jax.random import PRNGKey
- import jaxopt
+ from numpyro.infer.util import constrain_fn
+ from scipy.stats import Covariance, multivariate_normal

-
- T = TypeVar("T")
-
-
- class HaikuDict(dict[str, dict[str, T]]): ...
+ from .analysis.results import FitResult
+ from .data import ObsConfiguration
+ from .model.abc import SpectralModel
+ from .model.background import BackgroundModel
+ from .util import catchtime
+ from .util.typing import PriorDictModel, PriorDictType


- def build_prior(prior: HaikuDict[Distribution | ArrayLike], expand_shape: tuple = ()):
+ def build_prior(prior: PriorDictType, expand_shape: tuple = (), prefix=""):
+     """
+     Transform a dictionary of prior distributions into a dictionary of parameters sampled from the prior.
+     Must be used within a numpyro model.
+     """
      parameters = dict(hk.data_structures.to_haiku_dict(prior))

      for i, (m, n, sample) in enumerate(hk.data_structures.traverse(prior)):
-         match sample:
-             case Distribution():
-                 parameters[m][n] = jnp.ones(expand_shape) * numpyro.sample(f"{m}_{n}", sample)
-                 # parameters[m][n] = numpyro.sample(f"{m}_{n}", sample.expand(expand_shape)) build a free parameter for each obs
-             case float() | ArrayLike():
-                 parameters[m][n] = jnp.ones(expand_shape) * sample
-             case _:
-                 raise ValueError(f"Invalid prior type {type(sample)} for parameter {m}_{n} : {sample}")
+         if isinstance(sample, Distribution):
+             parameters[m][n] = jnp.ones(expand_shape) * numpyro.sample(f"{prefix}{m}_{n}", sample)
+
+         elif isinstance(sample, ArrayLike):
+             parameters[m][n] = jnp.ones(expand_shape) * sample
+
+         else:
+             raise ValueError(
+                 f"Invalid prior type {type(sample)} for parameter {prefix}{m}_{n} : {sample}"
+             )

      return parameters


- def build_numpyro_model(
+ def build_numpyro_model_for_single_obs(
      obs: ObsConfiguration,
      model: SpectralModel,
      background_model: BackgroundModel,
      name: str = "",
      sparse: bool = False,
  ) -> Callable:
-     def numpro_model(prior_params, observed=True):
+     """
+     Build a numpyro model for a given observation and spectral model.
+     """
+
+     def numpyro_model(prior_params, observed=True):
          # prior_params = build_prior(prior_distributions, name=name)
-         transformed_model = hk.without_apply_rng(hk.transform(lambda par: CountForwardModel(model, obs, sparse=sparse)(par)))
+         transformed_model = hk.without_apply_rng(
+             hk.transform(lambda par: CountForwardModel(model, obs, sparse=sparse)(par))
+         )

          if (getattr(obs, "folded_background", None) is not None) and (background_model is not None):
              bkg_countrate = background_model.numpyro_model(
-                 obs.out_energies, obs.folded_background.data, name="bkg_" + name, observed=observed
+                 obs, model, name="bkg_" + name, observed=observed
              )
          elif (getattr(obs, "folded_background", None) is None) and (background_model is not None):
-             raise ValueError("Trying to fit a background model but no background is linked to this observation")
+             raise ValueError(
+                 "Trying to fit a background model but no background is linked to this observation"
+             )

          else:
              bkg_countrate = 0.0
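
The new `build_prior` helper above takes a nested mapping whose leaves are either numpyro distributions (sampled inside the model) or plain array-likes (kept fixed and broadcast over observations). A minimal sketch of such a dictionary; the module and parameter names here are hypothetical, not taken from this diff:

import numpyro.distributions as dist

prior = {
    "powerlaw_1": {
        "alpha": dist.Uniform(0.0, 5.0),      # free parameter, sampled from the prior
        "norm": dist.LogUniform(1e-6, 1e-2),  # free parameter with a log-uniform prior
    },
    "tbabs_1": {
        "nh": 0.2,  # fixed value, broadcast to every observation
    },
}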
@@ -77,10 +94,12 @@ def build_numpyro_model(
              obs=obs.folded_counts.data if observed else None,
          )

-     return numpro_model
+     return numpyro_model


- def filter_inference_data(inference_data, observation_container, background_model=None) -> az.InferenceData:
+ def filter_inference_data(
+     inference_data, observation_container, background_model=None
+ ) -> az.InferenceData:
      predictive_parameters = []

      for key, value in observation_container.items():
@@ -109,8 +128,12 @@ class CountForwardModel(hk.Module):
          self.model = model
          self.energies = jnp.asarray(folding.in_energies)

-         if sparse:  # folding.transfer_matrix.data.density > 0.015 is a good criterion to consider sparsify
-             self.transfer_matrix = BCSR.from_scipy_sparse(folding.transfer_matrix.data.to_scipy_sparse().tocsr())  #
+         if (
+             sparse
+         ):  # folding.transfer_matrix.data.density > 0.015 is a good criterion to consider sparsify
+             self.transfer_matrix = BCOO.from_scipy_sparse(
+                 folding.transfer_matrix.data.to_scipy_sparse().tocsr()
+             )

          else:
              self.transfer_matrix = jnp.asarray(folding.transfer_matrix.data.todense())
@@ -125,14 +148,15 @@ class CountForwardModel(hk.Module):
          return jnp.clip(expected_counts, a_min=1e-6)


- class ModelFitter(ABC):
+ class BayesianModel:
      """
-     Abstract class to fit a model to a given set of observation.
+     Class to fit a model to a given set of observations.
      """

      def __init__(
          self,
          model: SpectralModel,
+         prior_distributions: PriorDictType | Callable,
          observations: ObsConfiguration | list[ObsConfiguration] | dict[str, ObsConfiguration],
          background_model: BackgroundModel = None,
          sparsify_matrix: bool = False,
@@ -142,6 +166,8 @@ class ModelFitter(ABC):

          Parameters:
              model: the spectral model to fit.
+             prior_distributions: a nested dictionary containing the prior distributions for the model parameters, or a
+                 callable function that returns parameter samples.
              observations: the observations to fit the model to.
              background_model: the background model to fit.
              sparsify_matrix: whether to sparsify the transfer matrix.
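
Because `prior_distributions` may also be a callable, here is a hedged sketch of that second form (site names are illustrative). The callable runs inside the numpyro model, which is why `numpyro.sample` is legal here, and it must return one parameter value per observation along the leading axis:

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def prior_distributions_func():
    # a single shared value, tiled here for e.g. two observations
    alpha = numpyro.sample("powerlaw_1_alpha", dist.Uniform(0.0, 5.0))
    norm = numpyro.sample("powerlaw_1_norm", dist.LogUniform(1e-6, 1e-2))
    return {"powerlaw_1": {"alpha": jnp.full((2,), alpha), "norm": jnp.full((2,), norm)}}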
@@ -152,8 +178,20 @@ class ModelFitter(ABC):
          self.pars = tree_map(lambda x: jnp.float64(x), self.model.params)
          self.sparse = sparsify_matrix

+         if not callable(prior_distributions):
+             # Validate the entry with pydantic
+             prior = PriorDictModel(nested_dict=prior_distributions).nested_dict
+
+             def prior_distributions_func():
+                 return build_prior(prior, expand_shape=(len(self.observation_container),))
+
+         else:
+             prior_distributions_func = prior_distributions
+
+         self.prior_distributions_func = prior_distributions_func
+
      @property
-     def _observation_container(self) -> dict[str, ObsConfiguration]:
+     def observation_container(self) -> dict[str, ObsConfiguration]:
          """
          The observations used in the fit as a dictionary of observations.
          """
@@ -170,33 +208,56 @@ class ModelFitter(ABC):
          else:
              raise ValueError(f"Invalid type for observations : {type(self._observations)}")

-     def numpyro_model(self, prior_distributions: HaikuDict[Distribution]) -> Callable:
+     @property
+     def numpyro_model(self) -> Callable:
          """
          Build the numpyro model using the observed data, the prior distributions and the spectral model.

-         Parameters:
-             prior_distributions: a nested dictionary containing the prior distributions for the model parameters.
-
          Returns:
+         -------
              A model function that can be used with numpyro.
          """

          def model(observed=True):
-             prior_params = build_prior(prior_distributions, expand_shape=(len(self._observation_container),))
+             prior_params = self.prior_distributions_func()

-             for i, (key, observation) in enumerate(self._observation_container.items()):
+             # Iterate over all the observations in our container and build a single numpyro model for each observation
+             for i, (key, observation) in enumerate(self.observation_container.items()):
+                 # We expect that prior_params contains an array of parameters for each observation
+                 # They can be identical or different for each observation
                  params = tree_map(lambda x: x[i], prior_params)

-                 obs_model = build_numpyro_model(observation, self.model, self.background_model, name=key, sparse=self.sparse)
+                 obs_model = build_numpyro_model_for_single_obs(
+                     observation, self.model, self.background_model, name=key, sparse=self.sparse
+                 )
+
                  obs_model(params, observed=observed)

          return model

+     @property
+     def transformed_numpyro_model(self) -> Callable:
+         transform_dict = {}
+
+         relations = get_model_relations(self.numpyro_model)
+         distributions = {
+             parameter: getattr(numpyro.distributions, value, None)
+             for parameter, value in relations["sample_dist"].items()
+         }
+
+         for parameter, distribution in distributions.items():
+             if isinstance(distribution, TransformedDistribution):
+                 transform_dict[parameter] = TransformReparam()
+
+         return numpyro.handlers.reparam(self.numpyro_model, config=transform_dict)
+
+
+ class BayesianModelFitter(BayesianModel, ABC):
      @abstractmethod
-     def fit(self, prior_distributions: HaikuDict[Distribution], **kwargs) -> FitResult: ...
+     def fit(self, **kwargs) -> FitResult: ...


- class BayesianFitter(ModelFitter):
+ class NUTSFitter(BayesianModelFitter):
      """
      A class to fit a model to a given set of observations using a Bayesian approach. This class uses the NUTS sampler
      from numpyro to perform the inference on the model parameters.
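
The `transformed_numpyro_model` property added above wires numpyro's TransformReparam into every sample site whose prior is a TransformedDistribution, so samplers operate in the unconstrained base space. A self-contained toy illustration of that handler, independent of jaxspec:

import numpyro
import numpyro.distributions as dist
from numpyro.infer.reparam import TransformReparam

def toy_model():
    # LogNormal is a TransformedDistribution (a Normal pushed through exp),
    # i.e. exactly the kind of site the property targets
    numpyro.sample("norm", dist.LogNormal(0.0, 1.0))

# Samples "norm_base" in the unconstrained space, then records the transformed
# value under the original site name.
reparam_model = numpyro.handlers.reparam(toy_model, config={"norm": TransformReparam()})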
@@ -204,21 +265,20 @@ class BayesianFitter(ModelFitter):

      def fit(
          self,
-         prior_distributions: HaikuDict[Distribution],
          rng_key: int = 0,
-         num_chains: int = 4,
+         num_chains: int = len(jax.devices()),
          num_warmup: int = 1000,
          num_samples: int = 1000,
          max_tree_depth: int = 10,
          target_accept_prob: float = 0.8,
          dense_mass: bool = False,
+         kernel_kwargs: dict = {},
          mcmc_kwargs: dict = {},
      ) -> FitResult:
          """
          Fit the model to the data using the NUTS sampler from numpyro.

          Parameters:
-             prior_distributions: a nested dictionary containing the prior distributions for the model parameters.
              rng_key: the random key used to initialize the sampler.
              num_chains: the number of chains to run.
              num_warmup: the number of warmup steps.
@@ -232,11 +292,8 @@ class BayesianFitter(ModelFitter):
              A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
          """

-         transform_dict = {}
-
-         for m, n, val in hk.data_structures.traverse(prior_distributions):
-             if isinstance(val, TransformedDistribution):
-                 transform_dict[f"{m}_{n}"] = TransformReparam()
+         bayesian_model = self.transformed_numpyro_model
+         # bayesian_model = self.numpyro_model(prior_distributions)

          chain_kwargs = {
              "num_warmup": num_warmup,
@@ -244,50 +301,62 @@ class BayesianFitter(ModelFitter):
              "num_chains": num_chains,
          }

-         bayesian_model = numpyro.handlers.reparam(self.numpyro_model(prior_distributions), config=transform_dict)
-
-         kernel = NUTS(bayesian_model, max_tree_depth=max_tree_depth, target_accept_prob=target_accept_prob, dense_mass=dense_mass)
+         kernel = NUTS(
+             bayesian_model,
+             max_tree_depth=max_tree_depth,
+             target_accept_prob=target_accept_prob,
+             dense_mass=dense_mass,
+             **kernel_kwargs,
+         )

          mcmc = MCMC(kernel, **(chain_kwargs | mcmc_kwargs))
-
          keys = random.split(random.PRNGKey(rng_key), 3)

          mcmc.run(keys[0])
-         posterior_predictive = Predictive(bayesian_model, mcmc.get_samples())(keys[1], observed=False)
+
+         posterior_predictive = Predictive(bayesian_model, mcmc.get_samples())(
+             keys[1], observed=False
+         )
+
          prior = Predictive(bayesian_model, num_samples=num_samples)(keys[2], observed=False)
-         inference_data = az.from_numpyro(mcmc, prior=prior, posterior_predictive=posterior_predictive)

-         inference_data = filter_inference_data(inference_data, self._observation_container, self.background_model)
+         inference_data = az.from_numpyro(
+             mcmc, prior=prior, posterior_predictive=posterior_predictive
+         )
+
+         inference_data = filter_inference_data(
+             inference_data, self.observation_container, self.background_model
+         )

          return FitResult(
-             self.model,
-             self._observation_container,
+             self,
              inference_data,
              self.model.params,
              background_model=self.background_model,
          )


- class MinimizationFitter(ModelFitter):
+ class MinimizationFitter(BayesianModelFitter):
      """
      A class to fit a model to a given set of observations using a minimization algorithm. This class uses the L-BFGS
      algorithm from jaxopt to perform the minimization on the model parameters. The uncertainties are computed using the
      Hessian of the log-likelihood, assuming that it is a multivariate Gaussian in the unbounded space defined by
      numpyro.
      """

      def fit(
          self,
-         prior_distributions: HaikuDict[Distribution],
          rng_key: int = 0,
-         num_iter_max: int = 10_000,
+         num_iter_max: int = 100_000,
          num_samples: int = 1_000,
+         solver: Literal["bfgs", "levenberg_marquardt"] = "bfgs",
+         init_params=None,
+         refine_first_guess=True,
      ) -> FitResult:
          """
          Fit the model to the data using the L-BFGS algorithm.

          Parameters:
-             prior_distributions: a nested dictionary containing the prior distributions for the model parameters.
              rng_key: the random key used to initialize the sampler.
              num_iter_max: the maximum number of iterations in the minimization algorithm.
              num_samples: the number of samples to draw from the best-fit covariance.
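
The uncertainty recipe in the docstring above is a Laplace approximation: minimize the negative log-likelihood, then treat the posterior as a Gaussian whose precision matrix is the Hessian at the optimum. A toy, self-contained sketch of that logic, with a quadratic standing in for the real likelihood:

import jax
import jax.numpy as jnp

def nll(theta):
    # toy quadratic negative log-likelihood, minimized at (1, 2)
    return 0.5 * jnp.sum((theta - jnp.array([1.0, 2.0])) ** 2)

theta_hat = jnp.array([1.0, 2.0])        # the minimizer, known here by construction
precision = jax.hessian(nll)(theta_hat)  # Hessian at the optimum
cov = jnp.linalg.inv(precision)          # covariance of the Gaussian approximation
samples = jax.random.multivariate_normal(
    jax.random.PRNGKey(0), theta_hat, cov, shape=(1000,)
)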
@@ -296,37 +365,205 @@ class MinimizationFitter(ModelFitter):
              A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
          """

-         bayesian_model = self.numpyro_model(prior_distributions)
+         bayesian_model = self.numpyro_model
+         keys = jax.random.split(PRNGKey(rng_key), 4)

-         param_info, potential_fn, postprocess_fn, *_ = initialize_model(
-             PRNGKey(0),
-             bayesian_model,
-             model_args=tuple(),
-             dynamic_args=True,  # <- this is important!
-         )
+         if init_params is not None:
+             # We initialize the parameters by randomly sampling from the prior
+             local_keys = jax.random.split(keys[0], 2)
+
+             with numpyro.handlers.seed(rng_seed=local_keys[0]):
+                 starting_value = self.prior_distributions_func()
+
+             # We update the starting value with the provided init_params
+             for m, n, val in hk.data_structures.traverse(init_params):
+                 if f"{m}_{n}" in starting_value.keys():
+                     starting_value[f"{m}_{n}"] = val
+
+             init_params, _ = numpyro.infer.util.find_valid_initial_params(
+                 local_keys[1], bayesian_model, init_strategy=init_to_value(values=starting_value)
+             )
+
+         else:
+             init_params, _ = numpyro.infer.util.find_valid_initial_params(keys[0], bayesian_model)
+
+         init_params = init_params[0]

-         # get negative log-density from the potential function
          @jax.jit
-         def nll_fn(position):
-             func = potential_fn()
-             return func(position)
+         def nll(unconstrained_params, _):
+             constrained_params = constrain_fn(
+                 bayesian_model, tuple(), dict(observed=True), unconstrained_params
+             )

-         solver = jaxopt.LBFGS(fun=nll_fn, maxiter=10_000)
-         params, state = solver.run(param_info.z)
-         keys = random.split(random.PRNGKey(rng_key), 3)
+             log_likelihood = numpyro.infer.util.log_likelihood(
+                 model=bayesian_model, posterior_samples=constrained_params
+             )

+             # We solve a least squares problem; this function ensures that the total residual is indeed the nll
+             return jax.tree.map(lambda x: jnp.sqrt(-x), log_likelihood)
+
+         """
+         if refine_first_guess:
+             with catchtime("Refine_first"):
+                 solution = optx.least_squares(
+                     nll,
+                     optx.BestSoFarMinimiser(optx.OptaxMinimiser(optax.adam(1e-4), 1e-6, 1e-6)),
+                     init_params,
+                     max_steps=1000,
+                     throw=False
+                 )
+             init_params = solution.value
+         """
+
+         if solver == "bfgs":
+             solver = optx.BestSoFarMinimiser(optx.BFGS(1e-6, 1e-6))
+         elif solver == "levenberg_marquardt":
+             solver = optx.BestSoFarLeastSquares(optx.LevenbergMarquardt(1e-6, 1e-6))
+         else:
+             raise NotImplementedError(f"{solver} is not implemented")
+
+         with catchtime("Minimization"):
+             solution = optx.least_squares(
+                 nll,
+                 solver,
+                 init_params,
+                 max_steps=num_iter_max,
+             )
+
+         params = solution.value
          value_flat, unflatten_fun = ravel_pytree(params)
-         covariance = jnp.linalg.inv(jax.hessian(lambda p: nll_fn(unflatten_fun(p)))(value_flat))

-         samples_flat = jax.random.multivariate_normal(keys[0], value_flat, covariance, shape=(num_samples,))
-         samples = jax.vmap(unflatten_fun)(samples_flat.block_until_ready())
-         posterior_samples = postprocess_fn()(samples)
+         with catchtime("Compute error"):
+             precision = jax.hessian(
+                 lambda p: jnp.sum(ravel_pytree(nll(unflatten_fun(p), None))[0] ** 2)
+             )(value_flat)

-         posterior_predictive = Predictive(bayesian_model, posterior_samples)(keys[1], observed=False)
-         prior = Predictive(bayesian_model, num_samples=num_samples)(keys[2], observed=False)
+             cov = Covariance.from_precision(precision)
+
+             samples_flat = multivariate_normal.rvs(mean=value_flat, cov=cov, size=num_samples)
+
+             samples = jax.vmap(unflatten_fun)(samples_flat)
+             posterior_samples = jax.jit(
+                 jax.vmap(lambda p: constrain_fn(bayesian_model, tuple(), dict(observed=True), p))
+             )(samples)
+
+         with catchtime("Posterior"):
+             posterior_predictive = Predictive(bayesian_model, posterior_samples)(
+                 keys[2], observed=False
+             )
+             prior = Predictive(bayesian_model, num_samples=num_samples)(keys[3], observed=False)
+             log_likelihood = numpyro.infer.log_likelihood(bayesian_model, posterior_samples)
+
+         def sanitize_chain(chain):
+             """
+             reshape the samples so that they are arviz compliant with an extra starting dimension
+             """
+             return tree_map(lambda x: x[None, ...], chain)
+
+         # We export the observed values to the inference_data
+         seeded_model = numpyro.handlers.substitute(
+             numpyro.handlers.seed(bayesian_model, jax.random.PRNGKey(0)),
+             substitute_fn=numpyro.infer.init_to_sample,
+         )
+         trace = numpyro.handlers.trace(seeded_model).get_trace()
+         observations = {
+             name: site["value"]
+             for name, site in trace.items()
+             if site["type"] == "sample" and site["is_observed"]
+         }
+
+         with catchtime("InferenceData wrapping"):
+             inference_data = az.from_dict(
+                 sanitize_chain(posterior_samples),
+                 prior=sanitize_chain(prior),
+                 posterior_predictive=sanitize_chain(posterior_predictive),
+                 log_likelihood=sanitize_chain(log_likelihood),
+                 observed_data=observations,
+             )
+
+         inference_data = filter_inference_data(
+             inference_data, self.observation_container, self.background_model
+         )
+
+         return FitResult(
+             self,
+             inference_data,
+             self.model.params,
+             background_model=self.background_model,
+         )
+
+
+ class NestedSamplingFitter(BayesianModelFitter):
+     r"""
+     A class to fit a model to a given set of observations using the Nested Sampling algorithm. This class uses the
+     [`DefaultNestedSampler`][jaxns.DefaultNestedSampler] from [`jaxns`](https://jaxns.readthedocs.io/en/latest/) which
+     implements the [Phantom-Powered Nested Sampling](https://arxiv.org/abs/2312.11330) algorithm.
+     Add Citation to jaxns
+     """
+
+     def fit(
+         self,
+         rng_key: int = 0,
+         num_samples: int = 1000,
+         plot_diagnostics=False,
+         termination_kwargs: dict | None = None,
+         verbose=True,
+     ) -> FitResult:
+         """
+         Fit the model to the data using the Phantom-Powered nested sampling algorithm.
+
+         Parameters:
+             rng_key: the random key used to initialize the sampler.
+             num_samples: the number of samples to draw.
+
+         Returns:
+             A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
+         """
+
+         bayesian_model = self.transformed_numpyro_model
+         keys = random.split(random.PRNGKey(rng_key), 4)
+
+         ns = NestedSampler(
+             bayesian_model,
+             constructor_kwargs=dict(
+                 num_parallel_workers=1,
+                 verbose=verbose,
+                 difficult_model=True,
+                 max_samples=1e6,
+                 parameter_estimation=True,
+                 num_live_points=1_000,
+             ),
+             termination_kwargs=termination_kwargs if termination_kwargs else dict(),
+         )
+
+         ns.run(keys[0])
+
+         if plot_diagnostics:
+             ns.diagnostics()
+
+         posterior_samples = ns.get_samples(keys[1], num_samples=num_samples)
          log_likelihood = numpyro.infer.log_likelihood(bayesian_model, posterior_samples)
+         posterior_predictive = Predictive(bayesian_model, posterior_samples)(
+             keys[2], observed=False
+         )
+
+         prior = Predictive(bayesian_model, num_samples=num_samples)(keys[3], observed=False)
+
+         seeded_model = numpyro.handlers.substitute(
+             numpyro.handlers.seed(bayesian_model, jax.random.PRNGKey(0)),
+             substitute_fn=numpyro.infer.init_to_sample,
+         )
+         trace = numpyro.handlers.trace(seeded_model).get_trace()
+         observations = {
+             name: site["value"]
+             for name, site in trace.items()
+             if site["type"] == "sample" and site["is_observed"]
+         }

          def sanitize_chain(chain):
+             """
+             reshape the samples so that they are arviz compliant with an extra starting dimension
+             """
              return tree_map(lambda x: x[None, ...], chain)

          inference_data = az.from_dict(
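
The sqrt trick inside `nll` above is what lets a least-squares solver minimize a log-likelihood: the solver minimizes the sum of squared residuals, so returning r_i = sqrt(-loglik_i) makes that sum equal the total negative log-likelihood (valid because pointwise Poisson log-probabilities are never positive). A quick numerical check of the identity:

import jax.numpy as jnp
from jax.scipy.stats import poisson

counts = jnp.array([3.0, 7.0, 0.0])
rate = jnp.array([2.5, 6.0, 1.0])

loglik = poisson.logpmf(counts, rate)  # each term is <= 0
residuals = jnp.sqrt(-loglik)          # the role played by nll's return value
assert jnp.allclose(jnp.sum(residuals**2), -jnp.sum(loglik))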
@@ -334,13 +571,15 @@ class MinimizationFitter(ModelFitter):
              prior=sanitize_chain(prior),
              posterior_predictive=sanitize_chain(posterior_predictive),
              log_likelihood=sanitize_chain(log_likelihood),
+             observed_data=observations,
          )

-         inference_data = filter_inference_data(inference_data, self._observation_container, self.background_model)
+         inference_data = filter_inference_data(
+             inference_data, self.observation_container, self.background_model
+         )

          return FitResult(
-             self.model,
-             self._observation_container,
+             self,
              inference_data,
              self.model.params,
              background_model=self.background_model,
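
The new NestedSamplingFitter mirrors the NUTSFitter interface, so switching samplers is a construction-time choice. A hypothetical sketch using the same placeholder objects as the NUTSFitter example above:

from jaxspec.fit import NestedSamplingFitter

fitter = NestedSamplingFitter(spectral_model, prior, obs)
result = fitter.fit(num_samples=1000, verbose=False)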
jaxspec/model/__init__.py CHANGED
@@ -1 +0,0 @@
-