jaxspec 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxspec/fit.py CHANGED
@@ -1,4 +1,5 @@
  import operator
+ import warnings

  from abc import ABC, abstractmethod
  from collections.abc import Callable
@@ -9,29 +10,27 @@ import arviz as az
  import haiku as hk
  import jax
  import jax.numpy as jnp
+ import matplotlib.pyplot as plt
+ import numpy as np
  import numpyro
- import optimistix as optx

  from jax import random
  from jax.experimental.sparse import BCOO
- from jax.flatten_util import ravel_pytree
  from jax.random import PRNGKey
  from jax.tree_util import tree_map
  from jax.typing import ArrayLike
  from numpyro.contrib.nested_sampling import NestedSampler
  from numpyro.distributions import Distribution, Poisson, TransformedDistribution
- from numpyro.infer import MCMC, NUTS, Predictive
- from numpyro.infer.initialization import init_to_value
+ from numpyro.infer import AIES, ESS, MCMC, NUTS, Predictive
  from numpyro.infer.inspect import get_model_relations
  from numpyro.infer.reparam import TransformReparam
- from numpyro.infer.util import constrain_fn, log_density
- from scipy.stats import Covariance, multivariate_normal
+ from numpyro.infer.util import log_density

+ from .analysis._plot import _plot_poisson_data_with_error
  from .analysis.results import FitResult
  from .data import ObsConfiguration
  from .model.abc import SpectralModel
  from .model.background import BackgroundModel
- from .util import catchtime
  from .util.typing import PriorDictModel, PriorDictType


@@ -100,27 +99,6 @@ def build_numpyro_model_for_single_obs(
      return numpyro_model


- def filter_inference_data(
-     inference_data, observation_container, background_model=None
- ) -> az.InferenceData:
-     predictive_parameters = []
-
-     for key, value in observation_container.items():
-         if background_model is not None:
-             predictive_parameters.append(f"obs_{key}")
-             predictive_parameters.append(f"bkg_{key}")
-         else:
-             predictive_parameters.append(f"obs_{key}")
-
-     inference_data.posterior_predictive = inference_data.posterior_predictive[predictive_parameters]
-
-     parameters = [x for x in inference_data.posterior.keys() if not x.endswith("_base")]
-     inference_data.posterior = inference_data.posterior[parameters]
-     inference_data.prior = inference_data.prior[parameters]
-
-     return inference_data
-
-
  class CountForwardModel(hk.Module):
      """
      A haiku module which allows one to build the function that simulates the measured counts
@@ -153,7 +131,8 @@ class CountForwardModel(hk.Module):

  class BayesianModel:
      """
-     Class to fit a model to a given set of observation.
+     Base class for a Bayesian model. This class contains the necessary methods to build a model, sample from the prior
+     and compute the log-likelihood and posterior probability.
      """

      def __init__(
@@ -165,6 +144,8 @@ class BayesianModel:
          sparsify_matrix: bool = False,
      ):
          """
+         Build a Bayesian model for a given spectral model and observations.
+
          Parameters:
              model: the spectral model to fit.
              prior_distributions: a nested dictionary containing the prior distributions for the model parameters, or a
@@ -181,7 +162,7 @@

          if not callable(prior_distributions):
              # Validate the entry with pydantic
-             prior = PriorDictModel(nested_dict=prior_distributions).nested_dict
+             prior = PriorDictModel.from_dict(prior_distributions).nested_dict

              def prior_distributions_func():
                  return build_prior(prior, expand_shape=(len(self.observation_container),))
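The constructor now validates nested prior dictionaries through `PriorDictModel.from_dict` before expanding them into a sampling function. A minimal sketch of the expected input, assuming a hypothetical component named `powerlaw_1` with `alpha` and `norm` parameters (actual names depend on the spectral model being fitted):

```python
import numpyro.distributions as dist

# Hypothetical nested prior dictionary: outer keys name model components,
# inner keys name their parameters, values are numpyro distributions.
prior_distributions = {
    "powerlaw_1": {
        "alpha": dist.Uniform(0.0, 5.0),
        "norm": dist.LogUniform(1e-6, 1e-2),
    }
}

# Validated by pydantic, then expanded with the observation count:
# prior = PriorDictModel.from_dict(prior_distributions).nested_dict
```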
@@ -190,7 +171,7 @@
              prior_distributions_func = prior_distributions

          self.prior_distributions_func = prior_distributions_func
-         self.init_params = self.get_initial_params()
+         self.init_params = self.prior_samples()

      @cached_property
      def observation_container(self) -> dict[str, ObsConfiguration]:
@@ -214,9 +195,6 @@
      def numpyro_model(self) -> Callable:
          """
          Build the numpyro model using the observed data, the prior distributions and the spectral model.
-
-         Returns:
-             A model function that can be used with numpyro.
          """

          def model(observed=True):
@@ -256,9 +234,6 @@
      def log_likelihood_per_obs(self) -> Callable:
          """
          Build the log likelihood function for each bin in each observation.
-
-         Returns:
-             Callable log-likelihood function.
          """

          @jax.jit
@@ -293,6 +268,9 @@
          that can be fetched with the [`parameter_names`][jaxspec.fit.BayesianModel.parameter_names].
          """

+         # This is required as numpyro.infer.util.log_density does not check parameter validity by itself
+         numpyro.enable_validation()
+
          @jax.jit
          def log_posterior_prob(constrained_params):
              log_posterior_prob, _ = log_density(
@@ -312,6 +290,16 @@
          observed_sites = relations["observed"]
          return [site for site in all_sites if site not in observed_sites]

+     @cached_property
+     def observation_names(self) -> list[str]:
+         """
+         List of the observed sites in the numpyro model.
+         """
+         relations = get_model_relations(self.numpyro_model)
+         all_sites = relations["sample_sample"].keys()
+         observed_sites = relations["observed"]
+         return [site for site in all_sites if site in observed_sites]
+
      def array_to_dict(self, theta):
          """
          Convert an array of parameters to a dictionary of parameters.
@@ -335,7 +323,7 @@

          return theta

-     def get_initial_params(self, key: PRNGKey = PRNGKey(0), num_samples: int = 1):
+     def prior_samples(self, key: PRNGKey = PRNGKey(0), num_samples: int = 100):
          """
          Get initial parameters for the model by sampling from the prior distribution

@@ -344,242 +332,264 @@
              num_samples: the number of samples to draw from the prior.
          """

-         return Predictive(
-             self.numpyro_model, return_sites=self.parameter_names, num_samples=num_samples
-         )(key, observed=False)
+         @jax.jit
+         def prior_sample(key):
+             return Predictive(
+                 self.numpyro_model, return_sites=self.parameter_names, num_samples=num_samples
+             )(key, observed=False)

+         return prior_sample(key)

- class BayesianModelFitter(BayesianModel, ABC):
-     @abstractmethod
-     def fit(self, **kwargs) -> FitResult: ...
+     def mock_observations(self, parameters, key: PRNGKey = PRNGKey(0)):
+         @jax.jit
+         def fakeit(key, parameters):
+             return Predictive(
+                 self.numpyro_model,
+                 return_sites=self.observation_names,
+                 posterior_samples=parameters,
+             )(key, observed=False)

+         return fakeit(key, parameters)
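`prior_samples` and `mock_observations` together form a jitted "fakeit"-style pipeline: draw parameter sets from the prior, then push them through the forward model to get synthetic folded spectra. A hedged usage sketch, where `forward` stands for any already-built `BayesianModel` instance (hypothetical variable name):

```python
from jax.random import PRNGKey

# Draw 50 parameter sets from the prior...
params = forward.prior_samples(key=PRNGKey(0), num_samples=50)

# ...and simulate one counts vector per draw for every observed site
# (dictionary keys follow the "obs_<name>" convention of the numpyro model).
mock = forward.mock_observations(params, key=PRNGKey(1))
```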

- class NUTSFitter(BayesianModelFitter):
-     """
-     A class to fit a model to a given set of observation using a Bayesian approach. This class uses the NUTS sampler
-     from numpyro to perform the inference on the model parameters.
-     """
+     def prior_predictive_coverage(
+         self, key: PRNGKey = PRNGKey(0), num_samples: int = 1000, percentiles: tuple = (16, 84)
+     ):
+         """
+         Check whether the prior predictive distribution covers the observed data.
+         """
+         key_prior, key_posterior = jax.random.split(key, 2)
+         prior_params = self.prior_samples(key=key_prior, num_samples=num_samples)
+         posterior_observations = self.mock_observations(prior_params, key=key_posterior)

-     def fit(
+         for key, value in self.observation_container.items():
+             fig, axs = plt.subplots(
+                 nrows=2, ncols=1, sharex=True, figsize=(8, 8), height_ratios=[3, 1]
+             )
+
+             _plot_poisson_data_with_error(
+                 axs[0],
+                 value.out_energies,
+                 value.folded_counts.values,
+                 percentiles=percentiles,
+             )
+
+             axs[0].stairs(
+                 np.max(posterior_observations["obs_" + key], axis=0),
+                 edges=[*list(value.out_energies[0]), value.out_energies[1][-1]],
+                 baseline=np.min(posterior_observations["obs_" + key], axis=0),
+                 alpha=0.3,
+                 fill=True,
+                 color=(0.15, 0.25, 0.45),
+             )
+
+             counts = posterior_observations["obs_" + key]
+             observed = value.folded_counts.values
+
+             num_samples = counts.shape[0]
+
+             less_than_obs = (counts < observed).sum(axis=0)
+             equal_to_obs = (counts == observed).sum(axis=0)
+
+             rank = (less_than_obs + 0.5 * equal_to_obs) / num_samples * 100
+
+             axs[1].stairs(rank, edges=[*list(value.out_energies[0]), value.out_energies[1][-1]])
+
+             axs[1].plot(
+                 (value.out_energies.min(), value.out_energies.max()),
+                 (50, 50),
+                 color="black",
+                 linestyle="--",
+             )
+
+             axs[1].set_xlabel("Energy (keV)")
+             axs[0].set_ylabel("Counts")
+             axs[1].set_ylabel("Rank (%)")
+             axs[1].set_ylim(0, 100)
+             axs[0].set_xlim(value.out_energies.min(), value.out_energies.max())
+             axs[0].loglog()
+             plt.suptitle(f"Prior Predictive coverage for {key}")
+             plt.show()
+
+
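The `rank` computed above is a mid-rank statistic: the fraction of prior predictive draws falling below the observed count, with ties given half weight so that heavily tied low-count Poisson channels stay well behaved. If the prior covers the data, ranks should scatter around 50% rather than pile up at 0 or 100. A self-contained sketch of the same computation on toy numbers:

```python
import numpy as np

rng = np.random.default_rng(0)
counts = rng.poisson(5.0, size=(1000, 4))  # 1000 prior predictive draws, 4 channels
observed = np.array([3, 5, 8, 30])         # the last channel sits far in the tail

less_than_obs = (counts < observed).sum(axis=0)
equal_to_obs = (counts == observed).sum(axis=0)
rank = (less_than_obs + 0.5 * equal_to_obs) / counts.shape[0] * 100

print(rank)  # well-covered channels land near 50, the outlier near 100
```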
+ class BayesianModelFitter(BayesianModel, ABC):
+     def build_inference_data(
          self,
-         rng_key: int = 0,
-         num_chains: int = len(jax.devices()),
-         num_warmup: int = 1000,
-         num_samples: int = 1000,
-         max_tree_depth: int = 10,
-         target_accept_prob: float = 0.8,
-         dense_mass: bool = False,
-         kernel_kwargs: dict = {},
-         mcmc_kwargs: dict = {},
-     ) -> FitResult:
+         posterior_samples,
+         num_chains: int = 1,
+         num_predictive_samples: int = 1000,
+         key: PRNGKey = PRNGKey(42),
+         use_transformed_model: bool = False,
+         filter_inference_data: bool = True,
+     ) -> az.InferenceData:
          """
-         Fit the model to the data using NUTS sampler from numpyro.
+         Build an [InferenceData][arviz.InferenceData] object from posterior samples.

          Parameters:
-             rng_key: the random key used to initialize the sampler.
-             num_chains: the number of chains to run.
-             num_warmup: the number of warmup steps.
-             num_samples: the number of samples to draw.
-             max_tree_depth: the recursion depth of NUTS sampler.
-             target_accept_prob: the target acceptance probability for the NUTS sampler.
-             dense_mass: whether to use a dense mass for the NUTS sampler.
-             mcmc_kwargs: additional arguments to pass to the MCMC sampler. See [`MCMC`][numpyro.infer.mcmc.MCMC] for more details.
-
-         Returns:
-             A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
+             posterior_samples: the samples from the posterior distribution.
+             num_chains: the number of chains used to sample the posterior.
+             num_predictive_samples: the number of samples to draw from the prior.
+             key: the random key used to draw the predictive samples.
+             use_transformed_model: whether to use the transformed model to build the InferenceData.
+             filter_inference_data: whether to filter the InferenceData to keep only the relevant parameters.
          """

-         bayesian_model = self.transformed_numpyro_model
-         # bayesian_model = self.numpyro_model(prior_distributions)
-
-         chain_kwargs = {
-             "num_warmup": num_warmup,
-             "num_samples": num_samples,
-             "num_chains": num_chains,
-         }
-
-         kernel = NUTS(
-             bayesian_model,
-             max_tree_depth=max_tree_depth,
-             target_accept_prob=target_accept_prob,
-             dense_mass=dense_mass,
-             **kernel_kwargs,
+         numpyro_model = (
+             self.transformed_numpyro_model if use_transformed_model else self.numpyro_model
          )

-         mcmc = MCMC(kernel, **(chain_kwargs | mcmc_kwargs))
-         keys = random.split(random.PRNGKey(rng_key), 3)
+         keys = random.split(key, 4)

-         mcmc.run(keys[0])
+         posterior_predictive = Predictive(numpyro_model, posterior_samples)(keys[0], observed=False)

-         posterior_predictive = Predictive(bayesian_model, mcmc.get_samples())(
+         prior = Predictive(numpyro_model, num_samples=num_predictive_samples * num_chains)(
              keys[1], observed=False
          )

-         prior = Predictive(bayesian_model, num_samples=num_samples)(keys[2], observed=False)
+         log_likelihood = numpyro.infer.log_likelihood(numpyro_model, posterior_samples)

-         inference_data = az.from_numpyro(
-             mcmc, prior=prior, posterior_predictive=posterior_predictive
+         seeded_model = numpyro.handlers.substitute(
+             numpyro.handlers.seed(numpyro_model, keys[3]),
+             substitute_fn=numpyro.infer.init_to_sample,
          )

-         inference_data = filter_inference_data(
-             inference_data, self.observation_container, self.background_model
-         )
+         observations = {
+             name: site["value"]
+             for name, site in numpyro.handlers.trace(seeded_model).get_trace().items()
+             if site["type"] == "sample" and site["is_observed"]
+         }

-         return FitResult(
-             self,
-             inference_data,
-             self.model.params,
-             background_model=self.background_model,
-         )
+         def reshape_first_dimension(arr):
+             new_dim = arr.shape[0] // num_chains
+             new_shape = (num_chains, new_dim) + arr.shape[1:]
+             reshaped_array = arr.reshape(new_shape)

+             return reshaped_array

- class MinimizationFitter(BayesianModelFitter):
-     """
-     A class to fit a model to a given set of observation using a minimization algorithm. This class uses the L-BFGS
-     algorithm from jaxopt to perform the minimization on the model parameters. The uncertainties are computed using the
-     Hessian of the log-log_likelihood, assuming that it is a multivariate Gaussian in the unbounded space defined by
-     numpyro.
-     """
+         posterior_samples = {
+             key: reshape_first_dimension(value) for key, value in posterior_samples.items()
+         }
+         prior = {key: value[None, :] for key, value in prior.items()}
+         posterior_predictive = {
+             key: reshape_first_dimension(value) for key, value in posterior_predictive.items()
+         }
+         log_likelihood = {
+             key: reshape_first_dimension(value) for key, value in log_likelihood.items()
+         }

-     def fit(
-         self,
-         rng_key: int = 0,
-         num_iter_max: int = 100_000,
-         num_samples: int = 1_000,
-         solver: Literal["bfgs", "levenberg_marquardt"] = "bfgs",
-         init_params=None,
-         refine_first_guess=True,
-     ) -> FitResult:
-         """
-         Fit the model to the data using L-BFGS algorithm.
+         inference_data = az.from_dict(
+             posterior_samples,
+             prior=prior,
+             posterior_predictive=posterior_predictive,
+             log_likelihood=log_likelihood,
+             observed_data=observations,
+         )

-         Parameters:
-             rng_key: the random key used to initialize the sampler.
-             num_iter_max: the maximum number of iteration in the minimization algorithm.
-             num_samples: the number of sample to draw from the best-fit covariance.
+         return (
+             self.filter_inference_data(inference_data) if filter_inference_data else inference_data
+         )
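`reshape_first_dimension` exists because ArviZ expects every group to carry leading `(chain, draw)` dimensions, while the posterior samples arrive flattened across chains. A standalone sketch of the convention with toy shapes (no jaxspec objects involved):

```python
import numpy as np

num_chains, num_draws = 4, 250
flat = np.zeros((num_chains * num_draws, 3))  # (chain * draw, parameter)

# (chain * draw, ...) -> (chain, draw, ...), matching ArviZ's layout
per_chain = flat.reshape((num_chains, num_draws) + flat.shape[1:])
print(per_chain.shape)  # (4, 250, 3)
```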

-         Returns:
-             A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
+     def filter_inference_data(
+         self,
+         inference_data: az.InferenceData,
+     ) -> az.InferenceData:
          """
+         Filter the inference data to keep only the relevant parameters for the observations.

-         bayesian_model = self.numpyro_model
-         keys = jax.random.split(PRNGKey(rng_key), 4)
+         - Removes predictive parameters from deterministic random variables (e.g. kernel of background GP)
+         - Removes parameters built from reparametrised variables (e.g. ending with `"_base"`)
+         """

-         if init_params is not None:
-             # We initialize the parameters by randomly sampling from the prior
-             local_keys = jax.random.split(keys[0], 2)
+         predictive_parameters = []

-             with numpyro.handlers.seed(rng_seed=local_keys[0]):
-                 starting_value = self.prior_distributions_func()
+         for key, value in self.observation_container.items():
+             if self.background_model is not None:
+                 predictive_parameters.append(f"obs_{key}")
+                 predictive_parameters.append(f"bkg_{key}")
+             else:
+                 predictive_parameters.append(f"obs_{key}")

-             # We update the starting value with the provided init_params
-             for m, n, val in hk.data_structures.traverse(init_params):
-                 if f"{m}_{n}" in starting_value.keys():
-                     starting_value[f"{m}_{n}"] = val
+         inference_data.posterior_predictive = inference_data.posterior_predictive[
+             predictive_parameters
+         ]

-             init_params, _ = numpyro.infer.util.find_valid_initial_params(
-                 local_keys[1], bayesian_model, init_strategy=init_to_value(values=starting_value)
-             )
+         parameters = [x for x in inference_data.posterior.keys() if not x.endswith("_base")]
+         inference_data.posterior = inference_data.posterior[parameters]
+         inference_data.prior = inference_data.prior[parameters]

-         else:
-             init_params, _ = numpyro.infer.util.find_valid_initial_params(keys[0], bayesian_model)
+         return inference_data

-         init_params = init_params[0]
+     @abstractmethod
+     def fit(self, **kwargs) -> FitResult: ...
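The `_base` filtering is plain xarray variable selection on the posterior group; a hedged standalone equivalent (toy `InferenceData` built on the spot):

```python
import arviz as az

# One chain, two draws, with a reparametrised "_base" companion site.
idata = az.from_dict({"norm": [[0.1, 0.2]], "norm_base": [[-1.0, 0.5]]})

# Keep only user-facing parameters, dropping the "_base" sites.
keep = [v for v in idata.posterior.data_vars if not str(v).endswith("_base")]
idata.posterior = idata.posterior[keep]
print(list(idata.posterior.data_vars))  # ['norm']
```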

-         @jax.jit
-         def nll(unconstrained_params, _):
-             constrained_params = constrain_fn(
-                 bayesian_model, tuple(), dict(observed=True), unconstrained_params
-             )

-             log_likelihood = numpyro.infer.util.log_likelihood(
-                 model=bayesian_model, posterior_samples=constrained_params
-             )
+ class MCMCFitter(BayesianModelFitter):
+     """
+     A class to fit a model to a given set of observations using a Bayesian approach. This class uses samplers
+     from numpyro to perform the inference on the model parameters.
+     """

-             # We solve a least square problem, this function ensure that the total residual is indeed the nll
-             return jax.tree.map(lambda x: jnp.sqrt(-x), log_likelihood)
+     kernel_dict = {
+         "nuts": NUTS,
+         "aies": AIES,
+         "ess": ESS,
+     }

+     def fit(
+         self,
+         rng_key: int = 0,
+         num_chains: int = len(jax.devices()),
+         num_warmup: int = 1000,
+         num_samples: int = 1000,
+         sampler: Literal["nuts", "aies", "ess"] = "nuts",
+         use_transformed_model: bool = True,
+         kernel_kwargs: dict = {},
+         mcmc_kwargs: dict = {},
+     ) -> FitResult:
          """
-         if refine_first_guess:
-             with catchtime("Refine_first"):
-                 solution = optx.least_squares(
-                     nll,
-                     optx.BestSoFarMinimiser(optx.OptaxMinimiser(optax.adam(1e-4), 1e-6, 1e-6)),
-                     init_params,
-                     max_steps=1000,
-                     throw=False
-                 )
-                 init_params = solution.value
-         """
+         Fit the model to the data using an MCMC sampler from numpyro.

-         if solver == "bfgs":
-             solver = optx.BestSoFarMinimiser(optx.BFGS(1e-6, 1e-6))
-         elif solver == "levenberg_marquardt":
-             solver = optx.BestSoFarLeastSquares(optx.LevenbergMarquardt(1e-6, 1e-6))
-         else:
-             raise NotImplementedError(f"{solver} is not implemented")
-
-         with catchtime("Minimization"):
-             solution = optx.least_squares(
-                 nll,
-                 solver,
-                 init_params,
-                 max_steps=num_iter_max,
-             )
+         Parameters:
+             rng_key: the random key used to initialize the sampler.
+             num_chains: the number of chains to run.
+             num_warmup: the number of warmup steps.
+             num_samples: the number of samples to draw.
+             sampler: the sampler to use. Can be one of "nuts", "aies" or "ess".
+             use_transformed_model: whether to use the transformed model to build the InferenceData.
+             kernel_kwargs: additional arguments to pass to the kernel. See [`MCMCKernel`][numpyro.infer.mcmc.MCMCKernel] for more details.
+             mcmc_kwargs: additional arguments to pass to the MCMC sampler. See [`MCMC`][numpyro.infer.mcmc.MCMC] for more details.

-         params = solution.value
-         value_flat, unflatten_fun = ravel_pytree(params)
+         Returns:
+             A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
+         """

-         with catchtime("Compute error"):
-             precision = jax.hessian(
-                 lambda p: jnp.sum(ravel_pytree(nll(unflatten_fun(p), None))[0] ** 2)
-             )(value_flat)
+         bayesian_model = (
+             self.transformed_numpyro_model if use_transformed_model else self.numpyro_model
+         )

-             cov = Covariance.from_precision(precision)
+         chain_kwargs = {
+             "num_warmup": num_warmup,
+             "num_samples": num_samples,
+             "num_chains": num_chains,
+         }

-         samples_flat = multivariate_normal.rvs(mean=value_flat, cov=cov, size=num_samples)
+         kernel = self.kernel_dict[sampler](bayesian_model, **kernel_kwargs)

-         samples = jax.vmap(unflatten_fun)(samples_flat)
-         posterior_samples = jax.jit(
-             jax.vmap(lambda p: constrain_fn(bayesian_model, tuple(), dict(observed=True), p))
-         )(samples)
+         mcmc_kwargs = chain_kwargs | mcmc_kwargs

-         with catchtime("Posterior"):
-             posterior_predictive = Predictive(bayesian_model, posterior_samples)(
-                 keys[2], observed=False
-             )
-             prior = Predictive(bayesian_model, num_samples=num_samples)(keys[3], observed=False)
-             log_likelihood = numpyro.infer.log_likelihood(bayesian_model, posterior_samples)
+         if sampler in ["aies", "ess"] and mcmc_kwargs.get("chain_method", None) != "vectorized":
+             mcmc_kwargs["chain_method"] = "vectorized"
+             warnings.warn("The chain_method is set to 'vectorized' for AIES and ESS samplers")

-         def sanitize_chain(chain):
-             """
-             reshape the samples so that it is arviz compliant with an extra starting dimension
-             """
-             return tree_map(lambda x: x[None, ...], chain)
+         mcmc = MCMC(kernel, **mcmc_kwargs)
+         keys = random.split(random.PRNGKey(rng_key), 3)

-         # We export the observed values to the inference_data
-         seeded_model = numpyro.handlers.substitute(
-             numpyro.handlers.seed(bayesian_model, jax.random.PRNGKey(0)),
-             substitute_fn=numpyro.infer.init_to_sample,
-         )
-         trace = numpyro.handlers.trace(seeded_model).get_trace()
-         observations = {
-             name: site["value"]
-             for name, site in trace.items()
-             if site["type"] == "sample" and site["is_observed"]
-         }
+         mcmc.run(keys[0])

-         with catchtime("InferenceData wrapping"):
-             inference_data = az.from_dict(
-                 sanitize_chain(posterior_samples),
-                 prior=sanitize_chain(prior),
-                 posterior_predictive=sanitize_chain(posterior_predictive),
-                 log_likelihood=sanitize_chain(log_likelihood),
-                 observed_data=observations,
-             )
+         posterior = mcmc.get_samples()

-         inference_data = filter_inference_data(
-             inference_data, self.observation_container, self.background_model
+         inference_data = self.build_inference_data(
+             posterior, num_chains=num_chains, use_transformed_model=use_transformed_model
          )

          return FitResult(
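A hedged usage sketch of the new `MCMCFitter`; the constructor arguments mirror `BayesianModel.__init__` seen earlier, so `model`, `prior_distributions` and `obs` are placeholders and the exact call shape may differ:

```python
# Hypothetical objects: a SpectralModel, a nested prior dict, observations.
fitter = MCMCFitter(model, prior_distributions, obs)

# "nuts" is the default; "aies" and "ess" are ensemble samplers and get
# chain_method="vectorized" forced on them (with a warning), as shown above.
result = fitter.fit(
    num_chains=4,
    num_warmup=1000,
    num_samples=1000,
    sampler="aies",
)
```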
@@ -590,18 +600,22 @@ class MinimizationFitter(BayesianModelFitter):
          )


- class NestedSamplingFitter(BayesianModelFitter):
+ class NSFitter(BayesianModelFitter):
      r"""
      A class to fit a model to a given set of observations using the Nested Sampling algorithm. This class uses the
      [`DefaultNestedSampler`][jaxns.DefaultNestedSampler] from [`jaxns`](https://jaxns.readthedocs.io/en/latest/) which
      implements the [Phantom-Powered Nested Sampling](https://arxiv.org/abs/2312.11330) algorithm.
-     Add Citation to jaxns
+
+     !!! info
+         Ensure the prior distributions cover a large prior volume so that the algorithm yields proper results.
+
      """

      def fit(
          self,
          rng_key: int = 0,
          num_samples: int = 1000,
+         num_live_points: int = 1000,
          plot_diagnostics=False,
          termination_kwargs: dict | None = None,
          verbose=True,
@@ -612,6 +626,10 @@ class NestedSamplingFitter(BayesianModelFitter):
          Parameters:
              rng_key: the random key used to initialize the sampler.
              num_samples: the number of samples to draw.
+             num_live_points: the number of live points to use at the start of the NS algorithm.
+             plot_diagnostics: whether to plot the diagnostics of the NS algorithm.
+             termination_kwargs: additional arguments to pass to the termination criterion of the NS algorithm.
+             verbose: whether to print the progress of the NS algorithm.

          Returns:
              A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
@@ -628,7 +646,7 @@ class NestedSamplingFitter(BayesianModelFitter):
                  difficult_model=True,
                  max_samples=1e6,
                  parameter_estimation=True,
-                 num_live_points=1_000,
+                 num_live_points=num_live_points,
              ),
              termination_kwargs=termination_kwargs if termination_kwargs else dict(),
          )
@@ -638,41 +656,9 @@ class NestedSamplingFitter(BayesianModelFitter):
          if plot_diagnostics:
              ns.diagnostics()

-         posterior_samples = ns.get_samples(keys[1], num_samples=num_samples)
-         log_likelihood = numpyro.infer.log_likelihood(bayesian_model, posterior_samples)
-         posterior_predictive = Predictive(bayesian_model, posterior_samples)(
-             keys[2], observed=False
-         )
-
-         prior = Predictive(bayesian_model, num_samples=num_samples)(keys[3], observed=False)
-
-         seeded_model = numpyro.handlers.substitute(
-             numpyro.handlers.seed(bayesian_model, jax.random.PRNGKey(0)),
-             substitute_fn=numpyro.infer.init_to_sample,
-         )
-         trace = numpyro.handlers.trace(seeded_model).get_trace()
-         observations = {
-             name: site["value"]
-             for name, site in trace.items()
-             if site["type"] == "sample" and site["is_observed"]
-         }
-
-         def sanitize_chain(chain):
-             """
-             reshape the samples so that it is arviz compliant with an extra starting dimension
-             """
-             return tree_map(lambda x: x[None, ...], chain)
-
-         inference_data = az.from_dict(
-             sanitize_chain(posterior_samples),
-             prior=sanitize_chain(prior),
-             posterior_predictive=sanitize_chain(posterior_predictive),
-             log_likelihood=sanitize_chain(log_likelihood),
-             observed_data=observations,
-         )
-
-         inference_data = filter_inference_data(
-             inference_data, self.observation_container, self.background_model
+         posterior = ns.get_samples(keys[1], num_samples=num_samples)
+         inference_data = self.build_inference_data(
+             posterior, num_chains=1, use_transformed_model=True
          )

          return FitResult(
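Likewise for the renamed `NSFitter`, exercising the new `num_live_points` knob (same hypothetical construction as above; the three constructor arguments are placeholders):

```python
fitter = NSFitter(model, prior_distributions, obs)

# More live points explore the prior volume more finely, at the cost of runtime.
result = fitter.fit(
    num_samples=1000,
    num_live_points=2000,
    plot_diagnostics=True,
)
```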