jaxspec 0.0.7__py3-none-any.whl → 0.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxspec/fit.py CHANGED
@@ -19,6 +19,7 @@ from numpyro.contrib.nested_sampling import NestedSampler
19
19
  from numpyro.distributions import Distribution, Poisson, TransformedDistribution
20
20
  from numpyro.infer import MCMC, NUTS, Predictive
21
21
  from numpyro.infer.initialization import init_to_value
22
+ from numpyro.infer.inspect import get_model_relations
22
23
  from numpyro.infer.reparam import TransformReparam
23
24
  from numpyro.infer.util import constrain_fn
24
25
  from scipy.stats import Covariance, multivariate_normal
@@ -27,7 +28,7 @@ from .analysis.results import FitResult
27
28
  from .data import ObsConfiguration
28
29
  from .model.abc import SpectralModel
29
30
  from .model.background import BackgroundModel
30
- from .util import catchtime, sample_prior
31
+ from .util import catchtime
31
32
  from .util.typing import PriorDictModel, PriorDictType
32
33
 
33
34
 
@@ -39,17 +40,16 @@ def build_prior(prior: PriorDictType, expand_shape: tuple = (), prefix=""):
39
40
  parameters = dict(hk.data_structures.to_haiku_dict(prior))
40
41
 
41
42
  for i, (m, n, sample) in enumerate(hk.data_structures.traverse(prior)):
42
- match sample:
43
- case Distribution():
44
- parameters[m][n] = jnp.ones(expand_shape) * numpyro.sample(
45
- f"{prefix}{m}_{n}", sample
46
- )
47
- case ArrayLike():
48
- parameters[m][n] = jnp.ones(expand_shape) * sample
49
- case _:
50
- raise ValueError(
51
- f"Invalid prior type {type(sample)} for parameter {prefix}{m}_{n} : {sample}"
52
- )
43
+ if isinstance(sample, Distribution):
44
+ parameters[m][n] = jnp.ones(expand_shape) * numpyro.sample(f"{prefix}{m}_{n}", sample)
45
+
46
+ elif isinstance(sample, ArrayLike):
47
+ parameters[m][n] = jnp.ones(expand_shape) * sample
48
+
49
+ else:
50
+ raise ValueError(
51
+ f"Invalid prior type {type(sample)} for parameter {prefix}{m}_{n} : {sample}"
52
+ )
53
53
 
54
54
  return parameters
55
55
 
@@ -90,7 +90,7 @@ def build_numpyro_model_for_single_obs(
90
90
  with numpyro.plate("obs_plate_" + name, len(obs.folded_counts)):
91
91
  numpyro.sample(
92
92
  "obs_" + name,
93
- Poisson(countrate + bkg_countrate * obs.folded_backratio.data),
93
+ Poisson(countrate + bkg_countrate / obs.folded_backratio.data),
94
94
  obs=obs.folded_counts.data if observed else None,
95
95
  )
96
96
 
@@ -148,14 +148,15 @@ class CountForwardModel(hk.Module):
148
148
  return jnp.clip(expected_counts, a_min=1e-6)
149
149
 
150
150
 
151
- class ModelFitter(ABC):
151
+ class BayesianModel:
152
152
  """
153
- Abstract class to fit a model to a given set of observation.
153
+ Class to fit a model to a given set of observation.
154
154
  """
155
155
 
156
156
  def __init__(
157
157
  self,
158
158
  model: SpectralModel,
159
+ prior_distributions: PriorDictType | Callable,
159
160
  observations: ObsConfiguration | list[ObsConfiguration] | dict[str, ObsConfiguration],
160
161
  background_model: BackgroundModel = None,
161
162
  sparsify_matrix: bool = False,
@@ -163,9 +164,10 @@ class ModelFitter(ABC):
163
164
  """
164
165
  Initialize the fitter.
165
166
 
166
- Parameters
167
- ----------
167
+ Parameters:
168
168
  model: the spectral model to fit.
169
+ prior_distributions: a nested dictionary containing the prior distributions for the model parameters, or a
170
+ callable function that returns parameter samples.
169
171
  observations: the observations to fit the model to.
170
172
  background_model: the background model to fit.
171
173
  sparsify_matrix: whether to sparsify the transfer matrix.
@@ -176,8 +178,20 @@ class ModelFitter(ABC):
176
178
  self.pars = tree_map(lambda x: jnp.float64(x), self.model.params)
177
179
  self.sparse = sparsify_matrix
178
180
 
181
+ if not callable(prior_distributions):
182
+ # Validate the entry with pydantic
183
+ prior = PriorDictModel(nested_dict=prior_distributions).nested_dict
184
+
185
+ def prior_distributions_func():
186
+ return build_prior(prior, expand_shape=(len(self.observation_container),))
187
+
188
+ else:
189
+ prior_distributions_func = prior_distributions
190
+
191
+ self.prior_distributions_func = prior_distributions_func
192
+
179
193
  @property
180
- def _observation_container(self) -> dict[str, ObsConfiguration]:
194
+ def observation_container(self) -> dict[str, ObsConfiguration]:
181
195
  """
182
196
  The observations used in the fit as a dictionary of observations.
183
197
  """
@@ -194,36 +208,21 @@ class ModelFitter(ABC):
194
208
  else:
195
209
  raise ValueError(f"Invalid type for observations : {type(self._observations)}")
196
210
 
197
- def numpyro_model(self, prior_distributions: PriorDictType | Callable) -> Callable:
211
+ @property
212
+ def numpyro_model(self) -> Callable:
198
213
  """
199
214
  Build the numpyro model using the observed data, the prior distributions and the spectral model.
200
215
 
201
- Parameters
202
- ----------
203
- prior_distributions: a nested dictionary containing the prior distributions for the model parameters.
204
-
205
216
  Returns:
206
217
  -------
207
218
  A model function that can be used with numpyro.
208
219
  """
209
220
 
210
- if not callable(prior_distributions):
211
- # Validate the entry with pydantic
212
- prior_distributions = PriorDictModel(nested_dict=prior_distributions).nested_dict
213
-
214
- def prior_distributions_func():
215
- return build_prior(
216
- prior_distributions, expand_shape=(len(self._observation_container),)
217
- )
218
-
219
- else:
220
- prior_distributions_func = prior_distributions
221
-
222
221
  def model(observed=True):
223
- prior_params = prior_distributions_func()
222
+ prior_params = self.prior_distributions_func()
224
223
 
225
224
  # Iterate over all the observations in our container and build a single numpyro model for each observation
226
- for i, (key, observation) in enumerate(self._observation_container.items()):
225
+ for i, (key, observation) in enumerate(self.observation_container.items()):
227
226
  # We expect that prior_params contains an array of parameters for each observation
228
227
  # They can be identical or different for each observation
229
228
  params = tree_map(lambda x: x[i], prior_params)
@@ -236,22 +235,29 @@ class ModelFitter(ABC):
236
235
 
237
236
  return model
238
237
 
239
@property
def transformed_numpyro_model(self) -> Callable:
    """
    Build the numpyro model with `TransformReparam` applied to transformed-distribution sites.

    Every sample site of `self.numpyro_model` whose distribution class is a subclass of
    `TransformedDistribution` is registered for reparametrization with `TransformReparam`.

    Returns:
        The `self.numpyro_model` callable wrapped in a `numpyro.handlers.reparam` handler.
    """
    # The property builds a fresh closure on each access; evaluate it once and reuse it
    # for both the inspection and the reparametrized model.
    model = self.numpyro_model
    relations = get_model_relations(model)

    transform_dict = {}
    for parameter, dist_name in relations["sample_dist"].items():
        # `sample_dist` maps each site to the *name* of its distribution class. Resolve the name
        # to the class and test it with `issubclass`: `isinstance` on a class object is always
        # False and would silently disable the reparametrization.
        # Names not exported by `numpyro.distributions` resolve to None and are left untouched.
        dist_class = getattr(numpyro.distributions, dist_name, None)
        if isinstance(dist_class, type) and issubclass(dist_class, TransformedDistribution):
            transform_dict[parameter] = TransformReparam()

    return numpyro.handlers.reparam(model, config=transform_dict)
249
254
 
255
class BayesianModelFitter(BayesianModel, ABC):
    """
    Abstract base for fitters that run an inference procedure on a `BayesianModel`.

    Concrete subclasses implement `fit`, which returns a
    [`FitResult`][jaxspec.analysis.results.FitResult].
    """

    @abstractmethod
    def fit(self, **kwargs) -> FitResult:
        """
        Fit the model to the observations.

        Parameters:
            kwargs: fitter-specific options, defined by each concrete subclass.

        Returns:
            A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
        """
252
258
 
253
259
 
254
- class BayesianFitter(ModelFitter):
260
+ class NUTSFitter(BayesianModelFitter):
255
261
  """
256
262
  A class to fit a model to a given set of observation using a Bayesian approach. This class uses the NUTS sampler
257
263
  from numpyro to perform the inference on the model parameters.
@@ -259,7 +265,6 @@ class BayesianFitter(ModelFitter):
259
265
 
260
266
  def fit(
261
267
  self,
262
- prior_distributions: PriorDictType,
263
268
  rng_key: int = 0,
264
269
  num_chains: int = len(jax.devices()),
265
270
  num_warmup: int = 1000,
@@ -273,9 +278,7 @@ class BayesianFitter(ModelFitter):
273
278
  """
274
279
  Fit the model to the data using NUTS sampler from numpyro.
275
280
 
276
- Parameters
277
- ----------
278
- prior_distributions: a nested dictionary containing the prior distributions for the model parameters.
281
+ Parameters:
279
282
  rng_key: the random key used to initialize the sampler.
280
283
  num_chains: the number of chains to run.
281
284
  num_warmup: the number of warmup steps.
@@ -286,11 +289,10 @@ class BayesianFitter(ModelFitter):
286
289
  mcmc_kwargs: additional arguments to pass to the MCMC sampler. See [`MCMC`][numpyro.infer.mcmc.MCMC] for more details.
287
290
 
288
291
  Returns:
289
- -------
290
292
  A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
291
293
  """
292
294
 
293
- bayesian_model = self.transformed_numpyro_model(prior_distributions)
295
+ bayesian_model = self.transformed_numpyro_model
294
296
  # bayesian_model = self.numpyro_model(prior_distributions)
295
297
 
296
298
  chain_kwargs = {
@@ -311,28 +313,30 @@ class BayesianFitter(ModelFitter):
311
313
  keys = random.split(random.PRNGKey(rng_key), 3)
312
314
 
313
315
  mcmc.run(keys[0])
316
+
314
317
  posterior_predictive = Predictive(bayesian_model, mcmc.get_samples())(
315
318
  keys[1], observed=False
316
319
  )
320
+
317
321
  prior = Predictive(bayesian_model, num_samples=num_samples)(keys[2], observed=False)
322
+
318
323
  inference_data = az.from_numpyro(
319
324
  mcmc, prior=prior, posterior_predictive=posterior_predictive
320
325
  )
321
326
 
322
327
  inference_data = filter_inference_data(
323
- inference_data, self._observation_container, self.background_model
328
+ inference_data, self.observation_container, self.background_model
324
329
  )
325
330
 
326
331
  return FitResult(
327
- self.model,
328
- self._observation_container,
332
+ self,
329
333
  inference_data,
330
334
  self.model.params,
331
335
  background_model=self.background_model,
332
336
  )
333
337
 
334
338
 
335
- class MinimizationFitter(ModelFitter):
339
+ class MinimizationFitter(BayesianModelFitter):
336
340
  """
337
341
  A class to fit a model to a given set of observation using a minimization algorithm. This class uses the L-BFGS
338
342
  algorithm from jaxopt to perform the minimization on the model parameters. The uncertainties are computed using the
@@ -342,7 +346,6 @@ class MinimizationFitter(ModelFitter):
342
346
 
343
347
  def fit(
344
348
  self,
345
- prior_distributions: PriorDictType,
346
349
  rng_key: int = 0,
347
350
  num_iter_max: int = 100_000,
348
351
  num_samples: int = 1_000,
@@ -353,27 +356,24 @@ class MinimizationFitter(ModelFitter):
353
356
  """
354
357
  Fit the model to the data using L-BFGS algorithm.
355
358
 
356
- Parameters
357
- ----------
358
- prior_distributions: a nested dictionary containing the prior distributions for the model parameters.
359
+ Parameters:
359
360
  rng_key: the random key used to initialize the sampler.
360
361
  num_iter_max: the maximum number of iteration in the minimization algorithm.
361
362
  num_samples: the number of sample to draw from the best-fit covariance.
362
363
 
363
364
  Returns:
364
- -------
365
365
  A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
366
366
  """
367
367
 
368
- bayesian_model = self.numpyro_model(prior_distributions)
368
+ bayesian_model = self.numpyro_model
369
369
  keys = jax.random.split(PRNGKey(rng_key), 4)
370
370
 
371
371
  if init_params is not None:
372
372
  # We initialize the parameters by randomly sampling from the prior
373
373
  local_keys = jax.random.split(keys[0], 2)
374
- starting_value = sample_prior(
375
- prior_distributions, key=local_keys[0], flat_parameters=True
376
- )
374
+
375
+ with numpyro.handlers.seed(rng_seed=local_keys[0]):
376
+ starting_value = self.prior_distributions_func()
377
377
 
378
378
  # We update the starting value with the provided init_params
379
379
  for m, n, val in hk.data_structures.traverse(init_params):
@@ -482,19 +482,18 @@ class MinimizationFitter(ModelFitter):
482
482
  )
483
483
 
484
484
  inference_data = filter_inference_data(
485
- inference_data, self._observation_container, self.background_model
485
+ inference_data, self.observation_container, self.background_model
486
486
  )
487
487
 
488
488
  return FitResult(
489
- self.model,
490
- self._observation_container,
489
+ self,
491
490
  inference_data,
492
491
  self.model.params,
493
492
  background_model=self.background_model,
494
493
  )
495
494
 
496
495
 
497
- class NestedSamplingFitter(ModelFitter):
496
+ class NestedSamplingFitter(BayesianModelFitter):
498
497
  r"""
499
498
  A class to fit a model to a given set of observation using the Nested Sampling algorithm. This class uses the
500
499
  [`DefaultNestedSampler`][jaxns.DefaultNestedSampler] from [`jaxns`](https://jaxns.readthedocs.io/en/latest/) which
@@ -504,18 +503,16 @@ class NestedSamplingFitter(ModelFitter):
504
503
 
505
504
  def fit(
506
505
  self,
507
- prior_distributions: PriorDictType,
508
506
  rng_key: int = 0,
509
- num_parallel_workers: int = len(jax.devices()),
510
507
  num_samples: int = 1000,
511
508
  plot_diagnostics=False,
509
+ termination_kwargs: dict | None = None,
512
510
  verbose=True,
513
511
  ) -> FitResult:
514
512
  """
515
513
  Fit the model to the data using the Phantom-Powered nested sampling algorithm.
516
514
 
517
515
  Parameters:
518
- prior_distributions: a nested dictionary containing the prior distributions for the model parameters.
519
516
  rng_key: the random key used to initialize the sampler.
520
517
  num_samples: the number of samples to draw.
521
518
 
@@ -523,39 +520,34 @@ class NestedSamplingFitter(ModelFitter):
523
520
  A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
524
521
  """
525
522
 
526
- bayesian_model = self.transformed_numpyro_model(prior_distributions)
523
+ bayesian_model = self.transformed_numpyro_model
527
524
  keys = random.split(random.PRNGKey(rng_key), 4)
528
525
 
529
526
  ns = NestedSampler(
530
527
  bayesian_model,
531
528
  constructor_kwargs=dict(
532
- num_parallel_workers=num_parallel_workers,
529
+ num_parallel_workers=1,
533
530
  verbose=verbose,
534
531
  difficult_model=True,
535
- # max_samples=1e6,
536
- # num_live_points=10_000,
537
- # init_efficiency_threshold=0.5,
532
+ max_samples=1e6,
538
533
  parameter_estimation=True,
534
+ num_live_points=1_000,
539
535
  ),
540
- termination_kwargs=dict(dlogZ=1e-2),
536
+ termination_kwargs=termination_kwargs if termination_kwargs else dict(),
541
537
  )
542
538
 
543
539
  ns.run(keys[0])
544
540
 
545
- self.ns = ns
546
-
547
541
  if plot_diagnostics:
548
542
  ns.diagnostics()
549
543
 
550
- posterior_samples = ns.get_samples(keys[1], num_samples=num_samples * num_parallel_workers)
544
+ posterior_samples = ns.get_samples(keys[1], num_samples=num_samples)
551
545
  log_likelihood = numpyro.infer.log_likelihood(bayesian_model, posterior_samples)
552
546
  posterior_predictive = Predictive(bayesian_model, posterior_samples)(
553
547
  keys[2], observed=False
554
548
  )
555
549
 
556
- prior = Predictive(bayesian_model, num_samples=num_samples * num_parallel_workers)(
557
- keys[3], observed=False
558
- )
550
+ prior = Predictive(bayesian_model, num_samples=num_samples)(keys[3], observed=False)
559
551
 
560
552
  seeded_model = numpyro.handlers.substitute(
561
553
  numpyro.handlers.seed(bayesian_model, jax.random.PRNGKey(0)),
@@ -583,12 +575,11 @@ class NestedSamplingFitter(ModelFitter):
583
575
  )
584
576
 
585
577
  inference_data = filter_inference_data(
586
- inference_data, self._observation_container, self.background_model
578
+ inference_data, self.observation_container, self.background_model
587
579
  )
588
580
 
589
581
  return FitResult(
590
- self.model,
591
- self._observation_container,
582
+ self,
592
583
  inference_data,
593
584
  self.model.params,
594
585
  background_model=self.background_model,