jaxspec 0.0.6__py3-none-any.whl → 0.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxspec/fit.py CHANGED
@@ -1,67 +1,84 @@
+ from abc import ABC, abstractmethod
+ from collections.abc import Callable
+ from typing import Literal
+
+ import arviz as az
  import haiku as hk
+ import jax
  import jax.numpy as jnp
  import numpyro
- import arviz as az
- import jax
- from typing import Callable, TypeVar
- from abc import ABC, abstractmethod
+ import optimistix as optx
+
  from jax import random
- from jax.tree_util import tree_map
+ from jax.experimental.sparse import BCOO
  from jax.flatten_util import ravel_pytree
- from jax.experimental.sparse import BCSR
- from .analysis.results import FitResult
- from .model.abc import SpectralModel
- from .data import ObsConfiguration
- from .model.background import BackgroundModel
- from numpyro.infer import MCMC, NUTS, Predictive
- from numpyro.distributions import Distribution, TransformedDistribution
- from numpyro.distributions import Poisson
+ from jax.random import PRNGKey
+ from jax.tree_util import tree_map
  from jax.typing import ArrayLike
+ from numpyro.contrib.nested_sampling import NestedSampler
+ from numpyro.distributions import Distribution, Poisson, TransformedDistribution
+ from numpyro.infer import MCMC, NUTS, Predictive
+ from numpyro.infer.initialization import init_to_value
  from numpyro.infer.reparam import TransformReparam
- from numpyro.infer.util import initialize_model
- from jax.random import PRNGKey
- import jaxopt
+ from numpyro.infer.util import constrain_fn
+ from scipy.stats import Covariance, multivariate_normal

-
- T = TypeVar("T")
-
-
- class HaikuDict(dict[str, dict[str, T]]): ...
+ from .analysis.results import FitResult
+ from .data import ObsConfiguration
+ from .model.abc import SpectralModel
+ from .model.background import BackgroundModel
+ from .util import catchtime, sample_prior
+ from .util.typing import PriorDictModel, PriorDictType


- def build_prior(prior: HaikuDict[Distribution | ArrayLike], expand_shape: tuple = ()):
+ def build_prior(prior: PriorDictType, expand_shape: tuple = (), prefix=""):
+     """
+     Transform a dictionary of prior distributions into a dictionary of parameters sampled from the prior.
+     Must be used within a numpyro model.
+     """
      parameters = dict(hk.data_structures.to_haiku_dict(prior))

      for i, (m, n, sample) in enumerate(hk.data_structures.traverse(prior)):
          match sample:
              case Distribution():
-                 parameters[m][n] = jnp.ones(expand_shape) * numpyro.sample(f"{m}_{n}", sample)
-                 # parameters[m][n] = numpyro.sample(f"{m}_{n}", sample.expand(expand_shape)) build a free parameter for each obs
-             case float() | ArrayLike():
+                 parameters[m][n] = jnp.ones(expand_shape) * numpyro.sample(
+                     f"{prefix}{m}_{n}", sample
+                 )
+             case ArrayLike():
                  parameters[m][n] = jnp.ones(expand_shape) * sample
              case _:
-                 raise ValueError(f"Invalid prior type {type(sample)} for parameter {m}_{n} : {sample}")
+                 raise ValueError(
+                     f"Invalid prior type {type(sample)} for parameter {prefix}{m}_{n} : {sample}"
+                 )

      return parameters


- def build_numpyro_model(
+ def build_numpyro_model_for_single_obs(
      obs: ObsConfiguration,
      model: SpectralModel,
      background_model: BackgroundModel,
      name: str = "",
      sparse: bool = False,
  ) -> Callable:
-     def numpro_model(prior_params, observed=True):
+     """
+     Build a numpyro model for a given observation and spectral model.
+     """
+
+     def numpyro_model(prior_params, observed=True):
          # prior_params = build_prior(prior_distributions, name=name)
-         transformed_model = hk.without_apply_rng(hk.transform(lambda par: CountForwardModel(model, obs, sparse=sparse)(par)))
+         transformed_model = hk.without_apply_rng(
+             hk.transform(lambda par: CountForwardModel(model, obs, sparse=sparse)(par))
+         )

          if (getattr(obs, "folded_background", None) is not None) and (background_model is not None):
              bkg_countrate = background_model.numpyro_model(
-                 obs.out_energies, obs.folded_background.data, name="bkg_" + name, observed=observed
+                 obs, model, name="bkg_" + name, observed=observed
              )
          elif (getattr(obs, "folded_background", None) is None) and (background_model is not None):
-             raise ValueError("Trying to fit a background model but no background is linked to this observation")
+             raise ValueError(
+                 "Trying to fit a background model but no background is linked to this observation"
+             )

          else:
              bkg_countrate = 0.0
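
Note — `build_prior` is the entry point that turns a nested prior dictionary into numpyro sample sites. A minimal sketch of how such a dictionary might be consumed, assuming `build_prior` is imported from `jaxspec.fit`; the component and parameter names ("powerlaw_1", "alpha", "norm", "redshift") are illustrative, not part of this diff:

    import numpyro.distributions as dist

    prior = {
        "powerlaw_1": {
            "alpha": dist.Uniform(0.0, 5.0),      # Distribution leaf -> sample site "powerlaw_1_alpha"
            "norm": dist.LogUniform(1e-6, 1e-2),  # Distribution leaf -> sample site "powerlaw_1_norm"
            "redshift": 0.5,                      # plain ArrayLike leaf -> kept fixed, only broadcast
        }
    }

    def model():
        # Inside a numpyro model, each Distribution leaf is sampled once and
        # broadcast to expand_shape, i.e. one value per observation.
        params = build_prior(prior, expand_shape=(3,), prefix="")
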
@@ -73,14 +90,16 @@ def build_numpyro_model(
          with numpyro.plate("obs_plate_" + name, len(obs.folded_counts)):
              numpyro.sample(
                  "obs_" + name,
-                 Poisson(countrate + bkg_countrate / obs.folded_backratio.data),
+                 Poisson(countrate + bkg_countrate * obs.folded_backratio.data),
                  obs=obs.folded_counts.data if observed else None,
              )

-     return numpro_model
+     return numpyro_model


- def filter_inference_data(inference_data, observation_container, background_model=None) -> az.InferenceData:
+ def filter_inference_data(
+     inference_data, observation_container, background_model=None
+ ) -> az.InferenceData:
      predictive_parameters = []

      for key, value in observation_container.items():
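
Note — the Poisson rate change above flips the background scaling from a division to a multiplication. A toy illustration of the corrected rate, assuming `folded_backratio` is the source-to-background scaling factor (all numbers made up):

    import jax.numpy as jnp

    countrate = jnp.array([2.0, 5.0])         # folded source model, counts per bin
    bkg_countrate = jnp.array([10.0, 30.0])   # background model, measured in the background region
    folded_backratio = 0.1                    # source-to-background extraction ratio

    # The background rate is rescaled to the source region by multiplying with
    # the ratio, rather than dividing by it as in 0.0.6.
    rate = countrate + bkg_countrate * folded_backratio
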
@@ -109,8 +128,12 @@ class CountForwardModel(hk.Module):
          self.model = model
          self.energies = jnp.asarray(folding.in_energies)

-         if sparse:  # folding.transfer_matrix.data.density > 0.015 is a good criterion to consider sparsify
-             self.transfer_matrix = BCSR.from_scipy_sparse(folding.transfer_matrix.data.to_scipy_sparse().tocsr())  #
+         if (
+             sparse
+         ):  # folding.transfer_matrix.data.density > 0.015 is a good criterion to consider sparsify
+             self.transfer_matrix = BCOO.from_scipy_sparse(
+                 folding.transfer_matrix.data.to_scipy_sparse().tocsr()
+             )

          else:
              self.transfer_matrix = jnp.asarray(folding.transfer_matrix.data.todense())
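
Note — the sparse path now builds a `BCOO` matrix instead of `BCSR`. A self-contained sketch of the operation this enables inside jit-compiled code (toy matrix, not jaxspec data):

    import jax.numpy as jnp
    from jax.experimental.sparse import BCOO

    dense = jnp.array([[1.0, 0.0, 0.0],
                       [0.0, 0.0, 2.0]])
    transfer_matrix = BCOO.fromdense(dense)   # stores only the non-zero entries
    flux = jnp.array([3.0, 4.0, 5.0])

    counts = transfer_matrix @ flux           # sparse matvec, same result as dense @ flux
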
@@ -140,7 +163,8 @@ class ModelFitter(ABC):
          """
          Initialize the fitter.

-         Parameters:
+         Parameters
+         ----------
          model: the spectral model to fit.
          observations: the observations to fit the model to.
          background_model: the background model to fit.
@@ -170,30 +194,61 @@ class ModelFitter(ABC):
          else:
              raise ValueError(f"Invalid type for observations : {type(self._observations)}")

-     def numpyro_model(self, prior_distributions: HaikuDict[Distribution]) -> Callable:
+     def numpyro_model(self, prior_distributions: PriorDictType | Callable) -> Callable:
          """
          Build the numpyro model using the observed data, the prior distributions and the spectral model.

-         Parameters:
+         Parameters
+         ----------
          prior_distributions: a nested dictionary containing the prior distributions for the model parameters.

          Returns:
+         -------
          A model function that can be used with numpyro.
          """

+         if not callable(prior_distributions):
+             # Validate the entry with pydantic
+             prior_distributions = PriorDictModel(nested_dict=prior_distributions).nested_dict
+
+             def prior_distributions_func():
+                 return build_prior(
+                     prior_distributions, expand_shape=(len(self._observation_container),)
+                 )
+
+         else:
+             prior_distributions_func = prior_distributions
+
          def model(observed=True):
-             prior_params = build_prior(prior_distributions, expand_shape=(len(self._observation_container),))
+             prior_params = prior_distributions_func()

+             # Iterate over all the observations in our container and build a single numpyro model for each observation
              for i, (key, observation) in enumerate(self._observation_container.items()):
+                 # We expect that prior_params contains an array of parameters for each observation
+                 # They can be identical or different for each observation
                  params = tree_map(lambda x: x[i], prior_params)

-                 obs_model = build_numpyro_model(observation, self.model, self.background_model, name=key, sparse=self.sparse)
+                 obs_model = build_numpyro_model_for_single_obs(
+                     observation, self.model, self.background_model, name=key, sparse=self.sparse
+                 )
+
                  obs_model(params, observed=observed)

          return model

+     def transformed_numpyro_model(self, prior_distributions: PriorDictType) -> Callable:
+         transform_dict = {}
+
+         for m, n, val in hk.data_structures.traverse(prior_distributions):
+             if isinstance(val, TransformedDistribution):
+                 transform_dict[f"{m}_{n}"] = TransformReparam()
+
+         return numpyro.handlers.reparam(
+             self.numpyro_model(prior_distributions), config=transform_dict
+         )
+
      @abstractmethod
-     def fit(self, prior_distributions: HaikuDict[Distribution], **kwargs) -> FitResult: ...
+     def fit(self, prior_distributions: PriorDictType, **kwargs) -> FitResult: ...


  class BayesianFitter(ModelFitter):
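
Note — a hedged sketch of the two entry points defined above; `fitter` stands for any concrete `ModelFitter` instance and the prior dictionary is illustrative:

    import numpyro.distributions as dist

    prior = {"powerlaw_1": {"alpha": dist.Uniform(0.0, 5.0)}}

    plain_model = fitter.numpyro_model(prior)                # priors validated through PriorDictModel
    reparam_model = fitter.transformed_numpyro_model(prior)
    # transformed_numpyro_model wraps the plain model with numpyro.handlers.reparam:
    # every TransformedDistribution prior gets a TransformReparam, so the sampler
    # works in the unconstrained base space instead of the constrained one.
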
@@ -204,20 +259,22 @@ class BayesianFitter(ModelFitter):

      def fit(
          self,
-         prior_distributions: HaikuDict[Distribution],
+         prior_distributions: PriorDictType,
          rng_key: int = 0,
-         num_chains: int = 4,
+         num_chains: int = len(jax.devices()),
          num_warmup: int = 1000,
          num_samples: int = 1000,
          max_tree_depth: int = 10,
          target_accept_prob: float = 0.8,
          dense_mass: bool = False,
+         kernel_kwargs: dict = {},
          mcmc_kwargs: dict = {},
      ) -> FitResult:
          """
          Fit the model to the data using NUTS sampler from numpyro.

-         Parameters:
+         Parameters
+         ----------
          prior_distributions: a nested dictionary containing the prior distributions for the model parameters.
          rng_key: the random key used to initialize the sampler.
          num_chains: the number of chains to run.
@@ -229,14 +286,12 @@ class BayesianFitter(ModelFitter):
          mcmc_kwargs: additional arguments to pass to the MCMC sampler. See [`MCMC`][numpyro.infer.mcmc.MCMC] for more details.

          Returns:
+         -------
          A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
          """

-         transform_dict = {}
-
-         for m, n, val in hk.data_structures.traverse(prior_distributions):
-             if isinstance(val, TransformedDistribution):
-                 transform_dict[f"{m}_{n}"] = TransformReparam()
+         bayesian_model = self.transformed_numpyro_model(prior_distributions)
+         # bayesian_model = self.numpyro_model(prior_distributions)

          chain_kwargs = {
              "num_warmup": num_warmup,
@@ -244,20 +299,29 @@ class BayesianFitter(ModelFitter):
              "num_chains": num_chains,
          }

-         bayesian_model = numpyro.handlers.reparam(self.numpyro_model(prior_distributions), config=transform_dict)
-
-         kernel = NUTS(bayesian_model, max_tree_depth=max_tree_depth, target_accept_prob=target_accept_prob, dense_mass=dense_mass)
+         kernel = NUTS(
+             bayesian_model,
+             max_tree_depth=max_tree_depth,
+             target_accept_prob=target_accept_prob,
+             dense_mass=dense_mass,
+             **kernel_kwargs,
+         )

          mcmc = MCMC(kernel, **(chain_kwargs | mcmc_kwargs))
-
          keys = random.split(random.PRNGKey(rng_key), 3)

          mcmc.run(keys[0])
-         posterior_predictive = Predictive(bayesian_model, mcmc.get_samples())(keys[1], observed=False)
+         posterior_predictive = Predictive(bayesian_model, mcmc.get_samples())(
+             keys[1], observed=False
+         )
          prior = Predictive(bayesian_model, num_samples=num_samples)(keys[2], observed=False)
-         inference_data = az.from_numpyro(mcmc, prior=prior, posterior_predictive=posterior_predictive)
+         inference_data = az.from_numpyro(
+             mcmc, prior=prior, posterior_predictive=posterior_predictive
+         )

-         inference_data = filter_inference_data(inference_data, self._observation_container, self.background_model)
+         inference_data = filter_inference_data(
+             inference_data, self._observation_container, self.background_model
+         )

          return FitResult(
              self.model,
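
Note — an illustrative call sequence for the NUTS fit above. Construction of `model` and `observations` is jaxspec-specific and elided; the constructor signature is assumed from the `ModelFitter.__init__` docstring, and the keyword values shown are the defaults in this version:

    import jax

    fitter = BayesianFitter(model, observations)
    result = fitter.fit(
        prior_distributions,
        num_chains=len(jax.devices()),  # new default: one chain per available device
        num_warmup=1000,
        num_samples=1000,
        kernel_kwargs={},               # new in 0.0.7, forwarded verbatim to the NUTS kernel
    )
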
@@ -272,61 +336,242 @@ class MinimizationFitter(ModelFitter):
      """
      A class to fit a model to a given set of observation using a minimization algorithm. This class uses the L-BFGS
      algorithm from jaxopt to perform the minimization on the model parameters. The uncertainties are computed using the
-     Hessian of the log-likelihood, assuming that it is a multivariate Gaussian in the unbounded space defined by
+     Hessian of the log-log_likelihood, assuming that it is a multivariate Gaussian in the unbounded space defined by
      numpyro.
      """

      def fit(
          self,
-         prior_distributions: HaikuDict[Distribution],
+         prior_distributions: PriorDictType,
          rng_key: int = 0,
-         num_iter_max: int = 10_000,
+         num_iter_max: int = 100_000,
          num_samples: int = 1_000,
+         solver: Literal["bfgs", "levenberg_marquardt"] = "bfgs",
+         init_params=None,
+         refine_first_guess=True,
      ) -> FitResult:
          """
          Fit the model to the data using L-BFGS algorithm.

-         Parameters:
+         Parameters
+         ----------
          prior_distributions: a nested dictionary containing the prior distributions for the model parameters.
          rng_key: the random key used to initialize the sampler.
          num_iter_max: the maximum number of iteration in the minimization algorithm.
          num_samples: the number of sample to draw from the best-fit covariance.

          Returns:
+         -------
          A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
          """

          bayesian_model = self.numpyro_model(prior_distributions)
+         keys = jax.random.split(PRNGKey(rng_key), 4)

-         param_info, potential_fn, postprocess_fn, *_ = initialize_model(
-             PRNGKey(0),
-             bayesian_model,
-             model_args=tuple(),
-             dynamic_args=True,  # <- this is important!
-         )
+         if init_params is not None:
+             # We initialize the parameters by randomly sampling from the prior
+             local_keys = jax.random.split(keys[0], 2)
+             starting_value = sample_prior(
+                 prior_distributions, key=local_keys[0], flat_parameters=True
+             )
+
+             # We update the starting value with the provided init_params
+             for m, n, val in hk.data_structures.traverse(init_params):
+                 if f"{m}_{n}" in starting_value.keys():
+                     starting_value[f"{m}_{n}"] = val
+
+             init_params, _ = numpyro.infer.util.find_valid_initial_params(
+                 local_keys[1], bayesian_model, init_strategy=init_to_value(values=starting_value)
+             )
+
+         else:
+             init_params, _ = numpyro.infer.util.find_valid_initial_params(keys[0], bayesian_model)
+
+         init_params = init_params[0]

-         # get negative log-density from the potential function
          @jax.jit
-         def nll_fn(position):
-             func = potential_fn()
-             return func(position)
+         def nll(unconstrained_params, _):
+             constrained_params = constrain_fn(
+                 bayesian_model, tuple(), dict(observed=True), unconstrained_params
+             )

-         solver = jaxopt.LBFGS(fun=nll_fn, maxiter=10_000)
-         params, state = solver.run(param_info.z)
-         keys = random.split(random.PRNGKey(rng_key), 3)
+             log_likelihood = numpyro.infer.util.log_likelihood(
+                 model=bayesian_model, posterior_samples=constrained_params
+             )
+
+             # We solve a least square problem, this function ensure that the total residual is indeed the nll
+             return jax.tree.map(lambda x: jnp.sqrt(-x), log_likelihood)
+
+         """
+         if refine_first_guess:
+             with catchtime("Refine_first"):
+                 solution = optx.least_squares(
+                     nll,
+                     optx.BestSoFarMinimiser(optx.OptaxMinimiser(optax.adam(1e-4), 1e-6, 1e-6)),
+                     init_params,
+                     max_steps=1000,
+                     throw=False
+                 )
+             init_params = solution.value
+         """

+         if solver == "bfgs":
+             solver = optx.BestSoFarMinimiser(optx.BFGS(1e-6, 1e-6))
+         elif solver == "levenberg_marquardt":
+             solver = optx.BestSoFarLeastSquares(optx.LevenbergMarquardt(1e-6, 1e-6))
+         else:
+             raise NotImplementedError(f"{solver} is not implemented")
+
+         with catchtime("Minimization"):
+             solution = optx.least_squares(
+                 nll,
+                 solver,
+                 init_params,
+                 max_steps=num_iter_max,
+             )
+
+         params = solution.value
          value_flat, unflatten_fun = ravel_pytree(params)
-         covariance = jnp.linalg.inv(jax.hessian(lambda p: nll_fn(unflatten_fun(p)))(value_flat))

-         samples_flat = jax.random.multivariate_normal(keys[0], value_flat, covariance, shape=(num_samples,))
-         samples = jax.vmap(unflatten_fun)(samples_flat.block_until_ready())
-         posterior_samples = postprocess_fn()(samples)
+         with catchtime("Compute error"):
+             precision = jax.hessian(
+                 lambda p: jnp.sum(ravel_pytree(nll(unflatten_fun(p), None))[0] ** 2)
+             )(value_flat)

-         posterior_predictive = Predictive(bayesian_model, posterior_samples)(keys[1], observed=False)
-         prior = Predictive(bayesian_model, num_samples=num_samples)(keys[2], observed=False)
+             cov = Covariance.from_precision(precision)
+
+             samples_flat = multivariate_normal.rvs(mean=value_flat, cov=cov, size=num_samples)
+
+             samples = jax.vmap(unflatten_fun)(samples_flat)
+             posterior_samples = jax.jit(
+                 jax.vmap(lambda p: constrain_fn(bayesian_model, tuple(), dict(observed=True), p))
+             )(samples)
+
+         with catchtime("Posterior"):
+             posterior_predictive = Predictive(bayesian_model, posterior_samples)(
+                 keys[2], observed=False
+             )
+             prior = Predictive(bayesian_model, num_samples=num_samples)(keys[3], observed=False)
+             log_likelihood = numpyro.infer.log_likelihood(bayesian_model, posterior_samples)
+
+         def sanitize_chain(chain):
+             """
+             reshape the samples so that it is arviz compliant with an extra starting dimension
+             """
+             return tree_map(lambda x: x[None, ...], chain)
+
+         # We export the observed values to the inference_data
+         seeded_model = numpyro.handlers.substitute(
+             numpyro.handlers.seed(bayesian_model, jax.random.PRNGKey(0)),
+             substitute_fn=numpyro.infer.init_to_sample,
+         )
+         trace = numpyro.handlers.trace(seeded_model).get_trace()
+         observations = {
+             name: site["value"]
+             for name, site in trace.items()
+             if site["type"] == "sample" and site["is_observed"]
+         }
+
+         with catchtime("InferenceData wrapping"):
+             inference_data = az.from_dict(
+                 sanitize_chain(posterior_samples),
+                 prior=sanitize_chain(prior),
+                 posterior_predictive=sanitize_chain(posterior_predictive),
+                 log_likelihood=sanitize_chain(log_likelihood),
+                 observed_data=observations,
+             )
+
+         inference_data = filter_inference_data(
+             inference_data, self._observation_container, self.background_model
+         )
+
+         return FitResult(
+             self.model,
+             self._observation_container,
+             inference_data,
+             self.model.params,
+             background_model=self.background_model,
+         )
+
+
+ class NestedSamplingFitter(ModelFitter):
+     r"""
+     A class to fit a model to a given set of observation using the Nested Sampling algorithm. This class uses the
+     [`DefaultNestedSampler`][jaxns.DefaultNestedSampler] from [`jaxns`](https://jaxns.readthedocs.io/en/latest/) which
+     implements the [Phantom-Powered Nested Sampling](https://arxiv.org/abs/2312.11330) algorithm.
+     Add Citation to jaxns
+     """
+
+     def fit(
+         self,
+         prior_distributions: PriorDictType,
+         rng_key: int = 0,
+         num_parallel_workers: int = len(jax.devices()),
+         num_samples: int = 1000,
+         plot_diagnostics=False,
+         verbose=True,
+     ) -> FitResult:
+         """
+         Fit the model to the data using the Phantom-Powered nested sampling algorithm.
+
+         Parameters:
+             prior_distributions: a nested dictionary containing the prior distributions for the model parameters.
+             rng_key: the random key used to initialize the sampler.
+             num_samples: the number of samples to draw.
+
+         Returns:
+             A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
+         """
+
+         bayesian_model = self.transformed_numpyro_model(prior_distributions)
+         keys = random.split(random.PRNGKey(rng_key), 4)
+
+         ns = NestedSampler(
+             bayesian_model,
+             constructor_kwargs=dict(
+                 num_parallel_workers=num_parallel_workers,
+                 verbose=verbose,
+                 difficult_model=True,
+                 # max_samples=1e6,
+                 # num_live_points=10_000,
+                 # init_efficiency_threshold=0.5,
+                 parameter_estimation=True,
+             ),
+             termination_kwargs=dict(dlogZ=1e-2),
+         )
+
+         ns.run(keys[0])
+
+         self.ns = ns
+
+         if plot_diagnostics:
+             ns.diagnostics()
+
+         posterior_samples = ns.get_samples(keys[1], num_samples=num_samples * num_parallel_workers)
          log_likelihood = numpyro.infer.log_likelihood(bayesian_model, posterior_samples)
+         posterior_predictive = Predictive(bayesian_model, posterior_samples)(
+             keys[2], observed=False
+         )
+
+         prior = Predictive(bayesian_model, num_samples=num_samples * num_parallel_workers)(
+             keys[3], observed=False
+         )
+
+         seeded_model = numpyro.handlers.substitute(
+             numpyro.handlers.seed(bayesian_model, jax.random.PRNGKey(0)),
+             substitute_fn=numpyro.infer.init_to_sample,
+         )
+         trace = numpyro.handlers.trace(seeded_model).get_trace()
+         observations = {
+             name: site["value"]
+             for name, site in trace.items()
+             if site["type"] == "sample" and site["is_observed"]
+         }

          def sanitize_chain(chain):
+             """
+             reshape the samples so that it is arviz compliant with an extra starting dimension
+             """
              return tree_map(lambda x: x[None, ...], chain)

          inference_data = az.from_dict(
@@ -334,9 +579,12 @@ class MinimizationFitter(ModelFitter):
              prior=sanitize_chain(prior),
              posterior_predictive=sanitize_chain(posterior_predictive),
              log_likelihood=sanitize_chain(log_likelihood),
+             observed_data=observations,
          )

-         inference_data = filter_inference_data(inference_data, self._observation_container, self.background_model)
+         inference_data = filter_inference_data(
+             inference_data, self._observation_container, self.background_model
+         )

          return FitResult(
              self.model,
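
Note — the `MinimizationFitter` rewrite above rests on one identity worth spelling out: with per-point residuals r_i = sqrt(-log p_i), the sum of squared residuals equals the negative log-likelihood, so a least-squares solver (Levenberg-Marquardt in particular) minimizes the NLL directly; the Hessian of that sum at the optimum is the precision matrix, whose inverse gives the Gaussian error estimate. A toy, self-contained check of this logic (not jaxspec API; a 1-D Gaussian model stands in for the Poisson likelihood):

    import jax
    import jax.numpy as jnp

    data = jnp.array([1.2, 0.8, 1.1])

    def nll(mu):
        # Per-point Gaussian log-density (unit variance); always negative here,
        # so sqrt(-log_prob) is well defined, as it is for the Poisson pmf
        # (whose probabilities never exceed 1).
        log_prob = -0.5 * (data - mu) ** 2 - 0.5 * jnp.log(2.0 * jnp.pi)
        residuals = jnp.sqrt(-log_prob)
        return jnp.sum(residuals**2)      # identical to -jnp.sum(log_prob)

    mu_hat = jnp.mean(data)               # least-squares optimum for this toy model
    precision = jax.hessian(nll)(mu_hat)  # d2(NLL)/dmu2 = n = 3.0
    variance = 1.0 / precision            # Gaussian approximation: Var(mu_hat) = 1/3
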
jaxspec/model/__init__.py CHANGED
@@ -1 +0,0 @@
-