jaxspec 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,35 @@
1
+ import matplotlib.pyplot as plt
2
+ import numpy as np
3
+
4
+ from jax.typing import ArrayLike
5
+ from scipy.stats import nbinom
6
+
7
+
8
def _plot_poisson_data_with_error(
    ax: plt.Axes,
    x_bins: ArrayLike,
    y: ArrayLike,
    percentiles: tuple = (16, 84),
):
    """
    Plot Poisson data with error bars. We extrapolate the intrinsic error of the observation assuming a prior rate
    distributed according to a Gamma RV.

    The vertical error bars are the requested percentiles of a negative binomial distribution with ``n = y`` and
    ``p = 0.5``. This is the predictive count distribution of a Poisson process whose rate follows a Gamma(y, 1)
    distribution.

    Parameters:
        ax: the matplotlib axes to draw on.
        x_bins: the bin edges, indexed as ``x_bins[0]`` (lower edges) and ``x_bins[1]`` (upper edges). Points are
            placed at the geometric mean of the edges and the horizontal error bars span each bin.
        y: the observed counts in each bin.
        percentiles: the lower and upper percentiles (in %) used for the vertical error bars.

    Returns:
        The container returned by ``ax.errorbar``.
    """
    x_bins = np.asarray(x_bins)
    y = np.asarray(y)

    # Geometric mean of the bin edges: the natural bin centre on a log-scaled axis.
    x_center = np.sqrt(x_bins[0] * x_bins[1])

    # scipy's nbinom requires n > 0 and returns NaN otherwise. An empty bin is the n -> 0 limit of the
    # negative binomial, i.e. a point mass at zero, so both bounds are 0 there.
    has_counts = y > 0
    n = np.where(has_counts, y, 1)
    y_low = np.where(has_counts, nbinom.ppf(percentiles[0] / 100, n, 0.5), 0)
    y_high = np.where(has_counts, nbinom.ppf(percentiles[1] / 100, n, 0.5), 0)

    # matplotlib rejects negative error lengths. These occur when the requested percentiles do not bracket y.
    y_err = np.clip([y - y_low, y_high - y], 0, None)

    return ax.errorbar(
        x_center,
        y,
        xerr=np.abs(x_bins - x_center),
        yerr=y_err,
        color="black",
        linestyle="none",
        alpha=0.3,
        capsize=2,
    )
jaxspec/data/obsconf.py CHANGED
@@ -229,5 +229,39 @@ class ObsConfiguration(xr.Dataset):
229
229
  attrs=observation.attrs | instrument.attrs,
230
230
  )
231
231
 
232
@classmethod
def mock_from_instrument(
    cls,
    instrument: Instrument,
    exposure: float,
    low_energy: float = 1e-20,
    high_energy: float = 1e20,
):
    """
    Build an observation configuration around an empty mock observation of ``instrument``.

    The mock observation has zero counts in every channel and a background ratio of one. Its attributes are the
    instrument attributes plus a ``"description"`` entry; on a key clash the instrument attribute wins.

    Parameters:
        instrument: the instrument object. The length of its ``instrument_channel`` coordinate sets the
            number of channels.
        exposure: the total exposure of the mock observation.
        low_energy: the lower bound of the energy range to consider.
        high_energy: the upper bound of the energy range to consider.
    """
    channel_count = len(instrument.coords["instrument_channel"])
    metadata = {"description": "Mock observation"} | instrument.attrs

    empty_observation = Observation.from_matrix(
        np.zeros(channel_count),
        sparse.eye(channel_count),
        np.arange(channel_count),
        np.zeros(channel_count, dtype=bool),
        exposure,
        backratio=np.ones(channel_count),
        attributes=metadata,
    )

    return cls.from_instrument(instrument, empty_observation, low_energy=low_energy, high_energy=high_energy)
265
+
232
266
  def plot_counts(self, **kwargs):
233
267
  return self.folded_counts.plot.step(x="e_min_folded", where="post", **kwargs)
jaxspec/fit.py CHANGED
@@ -10,29 +10,27 @@ import arviz as az
10
10
  import haiku as hk
11
11
  import jax
12
12
  import jax.numpy as jnp
13
+ import matplotlib.pyplot as plt
14
+ import numpy as np
13
15
  import numpyro
14
- import optimistix as optx
15
16
 
16
17
  from jax import random
17
18
  from jax.experimental.sparse import BCOO
18
- from jax.flatten_util import ravel_pytree
19
19
  from jax.random import PRNGKey
20
20
  from jax.tree_util import tree_map
21
21
  from jax.typing import ArrayLike
22
22
  from numpyro.contrib.nested_sampling import NestedSampler
23
23
  from numpyro.distributions import Distribution, Poisson, TransformedDistribution
24
24
  from numpyro.infer import AIES, ESS, MCMC, NUTS, Predictive
25
- from numpyro.infer.initialization import init_to_value
26
25
  from numpyro.infer.inspect import get_model_relations
27
26
  from numpyro.infer.reparam import TransformReparam
28
- from numpyro.infer.util import constrain_fn, log_density
29
- from scipy.stats import Covariance, multivariate_normal
27
+ from numpyro.infer.util import log_density
30
28
 
29
+ from .analysis._plot import _plot_poisson_data_with_error
31
30
  from .analysis.results import FitResult
32
31
  from .data import ObsConfiguration
33
32
  from .model.abc import SpectralModel
34
33
  from .model.background import BackgroundModel
35
- from .util import catchtime
36
34
  from .util.typing import PriorDictModel, PriorDictType
37
35
 
38
36
 
@@ -101,27 +99,6 @@ def build_numpyro_model_for_single_obs(
101
99
  return numpyro_model
102
100
 
103
101
 
104
- def filter_inference_data(
105
- inference_data, observation_container, background_model=None
106
- ) -> az.InferenceData:
107
- predictive_parameters = []
108
-
109
- for key, value in observation_container.items():
110
- if background_model is not None:
111
- predictive_parameters.append(f"obs_{key}")
112
- predictive_parameters.append(f"bkg_{key}")
113
- else:
114
- predictive_parameters.append(f"obs_{key}")
115
-
116
- inference_data.posterior_predictive = inference_data.posterior_predictive[predictive_parameters]
117
-
118
- parameters = [x for x in inference_data.posterior.keys() if not x.endswith("_base")]
119
- inference_data.posterior = inference_data.posterior[parameters]
120
- inference_data.prior = inference_data.prior[parameters]
121
-
122
- return inference_data
123
-
124
-
125
102
  class CountForwardModel(hk.Module):
126
103
  """
127
104
  A haiku module which allows to build the function that simulates the measured counts
@@ -154,7 +131,8 @@ class CountForwardModel(hk.Module):
154
131
 
155
132
  class BayesianModel:
156
133
  """
157
- Class to fit a model to a given set of observation.
134
+ Base class for a Bayesian model. This class contains the necessary methods to build a model, sample from the prior
135
+ and compute the log-likelihood and posterior probability.
158
136
  """
159
137
 
160
138
  def __init__(
@@ -166,6 +144,8 @@ class BayesianModel:
166
144
  sparsify_matrix: bool = False,
167
145
  ):
168
146
  """
147
+ Build a Bayesian model for a given spectral model and observations.
148
+
169
149
  Parameters:
170
150
  model: the spectral model to fit.
171
151
  prior_distributions: a nested dictionary containing the prior distributions for the model parameters, or a
@@ -191,7 +171,7 @@ class BayesianModel:
191
171
  prior_distributions_func = prior_distributions
192
172
 
193
173
  self.prior_distributions_func = prior_distributions_func
194
- self.init_params = self.get_initial_params()
174
+ self.init_params = self.prior_samples()
195
175
 
196
176
  @cached_property
197
177
  def observation_container(self) -> dict[str, ObsConfiguration]:
@@ -215,9 +195,6 @@ class BayesianModel:
215
195
  def numpyro_model(self) -> Callable:
216
196
  """
217
197
  Build the numpyro model using the observed data, the prior distributions and the spectral model.
218
-
219
- Returns:
220
- A model function that can be used with numpyro.
221
198
  """
222
199
 
223
200
  def model(observed=True):
@@ -257,9 +234,6 @@ class BayesianModel:
257
234
  def log_likelihood_per_obs(self) -> Callable:
258
235
  """
259
236
  Build the log likelihood function for each bins in each observation.
260
-
261
- Returns:
262
- Callable log-likelihood function.
263
237
  """
264
238
 
265
239
  @jax.jit
@@ -316,6 +290,16 @@ class BayesianModel:
316
290
  observed_sites = relations["observed"]
317
291
  return [site for site in all_sites if site not in observed_sites]
318
292
 
293
@cached_property
def observation_names(self) -> list[str]:
    """
    Names of the observed sample sites of the numpyro model.

    The sites keep the order of ``get_model_relations(...)["sample_sample"]``. Only those listed under
    ``"observed"`` are retained.
    """
    relations = get_model_relations(self.numpyro_model)
    observed = relations["observed"]
    return [name for name in relations["sample_sample"] if name in observed]
302
+
319
303
  def array_to_dict(self, theta):
320
304
  """
321
305
  Convert an array of parameters to a dictionary of parameters.
@@ -339,7 +323,7 @@ class BayesianModel:
339
323
 
340
324
  return theta
341
325
 
342
- def get_initial_params(self, key: PRNGKey = PRNGKey(0), num_samples: int = 1):
326
+ def prior_samples(self, key: PRNGKey = PRNGKey(0), num_samples: int = 100):
343
327
  """
344
328
  Get initial parameters for the model by sampling from the prior distribution
345
329
 
@@ -348,9 +332,84 @@ class BayesianModel:
348
332
  num_samples: the number of samples to draw from the prior.
349
333
  """
350
334
 
351
- return Predictive(
352
- self.numpyro_model, return_sites=self.parameter_names, num_samples=num_samples
353
- )(key, observed=False)
335
+ @jax.jit
336
+ def prior_sample(key):
337
+ return Predictive(
338
+ self.numpyro_model, return_sites=self.parameter_names, num_samples=num_samples
339
+ )(key, observed=False)
340
+
341
+ return prior_sample(key)
342
+
343
def mock_observations(self, parameters, key: PRNGKey = PRNGKey(0)):
    """
    Simulate the observed sites of the model for each set of parameters.

    Parameters:
        parameters: a dictionary of parameter samples (e.g. the output of ``prior_samples``). It is passed to
            numpyro's ``Predictive`` as ``posterior_samples``.
        key: the random key used to draw the simulated observations.

    Returns:
        The ``Predictive`` output, restricted to the sites in ``observation_names``.
    """

    def simulate(rng_key, params):
        predictive = Predictive(
            self.numpyro_model,
            posterior_samples=params,
            return_sites=self.observation_names,
        )
        return predictive(rng_key, observed=False)

    return jax.jit(simulate)(key, parameters)
353
+
354
def prior_predictive_coverage(
    self, key: PRNGKey = PRNGKey(0), num_samples: int = 1000, percentiles: tuple = (16, 84)
):
    """
    Check whether the prior predictive distribution covers the observed data.

    Parameters are drawn from the prior and used to simulate counts. For each observation, a two-panel figure
    is then shown:

    - top: the observed counts with their error bars, over the band spanned by the minimum and maximum
      simulated counts in each channel;
    - bottom: the rank (in %) of the observed counts among the simulated counts in each channel. Ranks close
      to 0 or 100 mean the observed value lies in the tail of the simulated counts.

    Parameters:
        key: the random key used to sample the prior and to simulate the observations.
        num_samples: the number of prior samples to draw.
        percentiles: the percentiles used for the error bars of the observed counts.
    """
    key_prior, key_simulation = jax.random.split(key, 2)
    prior_params = self.prior_samples(key=key_prior, num_samples=num_samples)
    simulated_observations = self.mock_observations(prior_params, key=key_simulation)

    for obs_name, obs_conf in self.observation_container.items():
        energies = obs_conf.out_energies
        # Lower edge of every bin plus the upper edge of the last one, as expected by ``stairs``.
        edges = [*list(energies[0]), energies[1][-1]]
        observed = obs_conf.folded_counts.values
        simulated = simulated_observations["obs_" + obs_name]

        fig, axs = plt.subplots(
            nrows=2, ncols=1, sharex=True, figsize=(8, 8), height_ratios=[3, 1]
        )

        _plot_poisson_data_with_error(
            axs[0],
            energies,
            observed,
            percentiles=percentiles,
        )

        axs[0].stairs(
            np.max(simulated, axis=0),
            edges=edges,
            baseline=np.min(simulated, axis=0),
            alpha=0.3,
            fill=True,
            color=(0.15, 0.25, 0.45),
        )

        # Mid-rank of the observed value among the simulations: ties count for half.
        n_simulations = simulated.shape[0]
        less_than_obs = (simulated < observed).sum(axis=0)
        equal_to_obs = (simulated == observed).sum(axis=0)
        rank = (less_than_obs + 0.5 * equal_to_obs) / n_simulations * 100

        axs[1].stairs(rank, edges=edges)

        axs[1].plot(
            (energies.min(), energies.max()),
            (50, 50),
            color="black",
            linestyle="--",
        )

        axs[1].set_xlabel("Energy (keV)")
        axs[0].set_ylabel("Counts")
        axs[1].set_ylabel("Rank (%)")
        axs[1].set_ylim(0, 100)
        axs[0].set_xlim(energies.min(), energies.max())
        axs[0].loglog()
        fig.suptitle(f"Prior Predictive coverage for {obs_name}")
        plt.show()
354
413
 
355
414
 
356
415
  class BayesianModelFitter(BayesianModel, ABC):
@@ -359,11 +418,20 @@ class BayesianModelFitter(BayesianModel, ABC):
359
418
  posterior_samples,
360
419
  num_chains: int = 1,
361
420
  num_predictive_samples: int = 1000,
362
- key: PRNGKey = PRNGKey(0),
421
+ key: PRNGKey = PRNGKey(42),
363
422
  use_transformed_model: bool = False,
423
+ filter_inference_data: bool = True,
364
424
  ) -> az.InferenceData:
365
425
  """
366
- Build an InferenceData object from the posterior samples.
426
+ Build an [InferenceData][arviz.InferenceData] object from posterior samples.
427
+
428
+ Parameters:
429
+ posterior_samples: the samples from the posterior distribution.
430
+ num_chains: the number of chains used to sample the posterior.
431
+ num_predictive_samples: the number of samples to draw from the prior.
432
+ key: the random key used to initialize the sampler.
433
+ use_transformed_model: whether to use the transformed model to build the InferenceData.
434
+ filter_inference_data: whether to filter the InferenceData to keep only the relevant parameters.
367
435
  """
368
436
 
369
437
  numpyro_model = (
@@ -409,7 +477,7 @@ class BayesianModelFitter(BayesianModel, ABC):
409
477
  key: reshape_first_dimension(value) for key, value in log_likelihood.items()
410
478
  }
411
479
 
412
- return az.from_dict(
480
+ inference_data = az.from_dict(
413
481
  posterior_samples,
414
482
  prior=prior,
415
483
  posterior_predictive=posterior_predictive,
@@ -417,81 +485,42 @@ class BayesianModelFitter(BayesianModel, ABC):
417
485
  observed_data=observations,
418
486
  )
419
487
 
420
- @abstractmethod
421
- def fit(self, **kwargs) -> FitResult: ...
422
-
423
-
424
- class NUTSFitter(BayesianModelFitter):
425
- """
426
- A class to fit a model to a given set of observation using a Bayesian approach. This class uses the NUTS sampler
427
- from numpyro to perform the inference on the model parameters.
428
- """
488
+ return (
489
+ self.filter_inference_data(inference_data) if filter_inference_data else inference_data
490
+ )
429
491
 
430
- def fit(
492
+ def filter_inference_data(
431
493
  self,
432
- rng_key: int = 0,
433
- num_chains: int = len(jax.devices()),
434
- num_warmup: int = 1000,
435
- num_samples: int = 1000,
436
- max_tree_depth: int = 10,
437
- target_accept_prob: float = 0.8,
438
- dense_mass: bool = False,
439
- kernel_kwargs: dict = {},
440
- mcmc_kwargs: dict = {},
441
- ) -> FitResult:
494
+ inference_data: az.InferenceData,
495
+ ) -> az.InferenceData:
442
496
  """
443
- Fit the model to the data using NUTS sampler from numpyro.
497
+ Filter the inference data to keep only the relevant parameters for the observations.
444
498
 
445
- Parameters:
446
- rng_key: the random key used to initialize the sampler.
447
- num_chains: the number of chains to run.
448
- num_warmup: the number of warmup steps.
449
- num_samples: the number of samples to draw.
450
- max_tree_depth: the recursion depth of NUTS sampler.
451
- target_accept_prob: the target acceptance probability for the NUTS sampler.
452
- dense_mass: whether to use a dense mass for the NUTS sampler.
453
- mcmc_kwargs: additional arguments to pass to the MCMC sampler. See [`MCMC`][numpyro.infer.mcmc.MCMC] for more details.
454
-
455
- Returns:
456
- A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
499
+ - Removes predictive parameters from deterministic random variables (e.g. kernel of background GP)
500
+ - Removes parameters build from reparametrised variables (e.g. ending with `"_base"`)
457
501
  """
458
502
 
459
- bayesian_model = self.transformed_numpyro_model
460
- # bayesian_model = self.numpyro_model(prior_distributions)
461
-
462
- chain_kwargs = {
463
- "num_warmup": num_warmup,
464
- "num_samples": num_samples,
465
- "num_chains": num_chains,
466
- }
467
-
468
- kernel = NUTS(
469
- bayesian_model,
470
- max_tree_depth=max_tree_depth,
471
- target_accept_prob=target_accept_prob,
472
- dense_mass=dense_mass,
473
- **kernel_kwargs,
474
- )
503
+ predictive_parameters = []
475
504
 
476
- mcmc = MCMC(kernel, **(chain_kwargs | mcmc_kwargs))
477
- keys = random.split(random.PRNGKey(rng_key), 3)
505
+ for key, value in self.observation_container.items():
506
+ if self.background_model is not None:
507
+ predictive_parameters.append(f"obs_{key}")
508
+ predictive_parameters.append(f"bkg_{key}")
509
+ else:
510
+ predictive_parameters.append(f"obs_{key}")
478
511
 
479
- mcmc.run(keys[0])
512
+ inference_data.posterior_predictive = inference_data.posterior_predictive[
513
+ predictive_parameters
514
+ ]
480
515
 
481
- posterior = mcmc.get_samples()
516
+ parameters = [x for x in inference_data.posterior.keys() if not x.endswith("_base")]
517
+ inference_data.posterior = inference_data.posterior[parameters]
518
+ inference_data.prior = inference_data.prior[parameters]
482
519
 
483
- inference_data = filter_inference_data(
484
- self.build_inference_data(posterior, num_chains=num_chains),
485
- self.observation_container,
486
- self.background_model,
487
- )
520
+ return inference_data
488
521
 
489
- return FitResult(
490
- self,
491
- inference_data,
492
- self.model.params,
493
- background_model=self.background_model,
494
- )
522
+ @abstractmethod
523
+ def fit(self, **kwargs) -> FitResult: ...
495
524
 
496
525
 
497
526
  class MCMCFitter(BayesianModelFitter):
@@ -513,6 +542,7 @@ class MCMCFitter(BayesianModelFitter):
513
542
  num_warmup: int = 1000,
514
543
  num_samples: int = 1000,
515
544
  sampler: Literal["nuts", "aies", "ess"] = "nuts",
545
+ use_transformed_model: bool = True,
516
546
  kernel_kwargs: dict = {},
517
547
  mcmc_kwargs: dict = {},
518
548
  ) -> FitResult:
@@ -524,17 +554,18 @@ class MCMCFitter(BayesianModelFitter):
524
554
  num_chains: the number of chains to run.
525
555
  num_warmup: the number of warmup steps.
526
556
  num_samples: the number of samples to draw.
527
- max_tree_depth: the recursion depth of NUTS sampler.
528
- target_accept_prob: the target acceptance probability for the NUTS sampler.
529
- dense_mass: whether to use a dense mass for the NUTS sampler.
557
+ sampler: the sampler to use. Can be one of "nuts", "aies" or "ess".
558
+ use_transformed_model: whether to use the transformed model to build the InferenceData.
559
+ kernel_kwargs: additional arguments to pass to the kernel. See [`NUTS`][numpyro.infer.mcmc.MCMCKernel] for more details.
530
560
  mcmc_kwargs: additional arguments to pass to the MCMC sampler. See [`MCMC`][numpyro.infer.mcmc.MCMC] for more details.
531
561
 
532
562
  Returns:
533
563
  A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
534
564
  """
535
565
 
536
- bayesian_model = self.transformed_numpyro_model
537
- # bayesian_model = self.numpyro_model(prior_distributions)
566
+ bayesian_model = (
567
+ self.transformed_numpyro_model if use_transformed_model else self.numpyro_model
568
+ )
538
569
 
539
570
  chain_kwargs = {
540
571
  "num_warmup": num_warmup,
@@ -557,10 +588,8 @@ class MCMCFitter(BayesianModelFitter):
557
588
 
558
589
  posterior = mcmc.get_samples()
559
590
 
560
- inference_data = filter_inference_data(
561
- self.build_inference_data(posterior, num_chains=num_chains),
562
- self.observation_container,
563
- self.background_model,
591
+ inference_data = self.build_inference_data(
592
+ posterior, num_chains=num_chains, use_transformed_model=True
564
593
  )
565
594
 
566
595
  return FitResult(
@@ -571,175 +600,22 @@ class MCMCFitter(BayesianModelFitter):
571
600
  )
572
601
 
573
602
 
574
- class MinimizationFitter(BayesianModelFitter):
575
- """
576
- A class to fit a model to a given set of observation using a minimization algorithm. This class uses the L-BFGS
577
- algorithm from jaxopt to perform the minimization on the model parameters. The uncertainties are computed using the
578
- Hessian of the log-log_likelihood, assuming that it is a multivariate Gaussian in the unbounded space defined by
579
- numpyro.
580
- """
581
-
582
- def fit(
583
- self,
584
- rng_key: int = 0,
585
- num_iter_max: int = 100_000,
586
- num_samples: int = 1_000,
587
- solver: Literal["bfgs", "levenberg_marquardt"] = "bfgs",
588
- init_params=None,
589
- refine_first_guess=True,
590
- ) -> FitResult:
591
- """
592
- Fit the model to the data using L-BFGS algorithm.
593
-
594
- Parameters:
595
- rng_key: the random key used to initialize the sampler.
596
- num_iter_max: the maximum number of iteration in the minimization algorithm.
597
- num_samples: the number of sample to draw from the best-fit covariance.
598
-
599
- Returns:
600
- A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
601
- """
602
-
603
- bayesian_model = self.numpyro_model
604
- keys = jax.random.split(PRNGKey(rng_key), 4)
605
-
606
- if init_params is not None:
607
- # We initialize the parameters by randomly sampling from the prior
608
- local_keys = jax.random.split(keys[0], 2)
609
-
610
- with numpyro.handlers.seed(rng_seed=local_keys[0]):
611
- starting_value = self.prior_distributions_func()
612
-
613
- # We update the starting value with the provided init_params
614
- for m, n, val in hk.data_structures.traverse(init_params):
615
- if f"{m}_{n}" in starting_value.keys():
616
- starting_value[f"{m}_{n}"] = val
617
-
618
- init_params, _ = numpyro.infer.util.find_valid_initial_params(
619
- local_keys[1], bayesian_model, init_strategy=init_to_value(values=starting_value)
620
- )
621
-
622
- else:
623
- init_params, _ = numpyro.infer.util.find_valid_initial_params(keys[0], bayesian_model)
624
-
625
- init_params = init_params[0]
626
-
627
- @jax.jit
628
- def nll(unconstrained_params, _):
629
- constrained_params = constrain_fn(
630
- bayesian_model, tuple(), dict(observed=True), unconstrained_params
631
- )
632
-
633
- log_likelihood = numpyro.infer.util.log_likelihood(
634
- model=bayesian_model, posterior_samples=constrained_params
635
- )
636
-
637
- # We solve a least square problem, this function ensure that the total residual is indeed the nll
638
- return jax.tree.map(lambda x: jnp.sqrt(-x), log_likelihood)
639
-
640
- """
641
- if refine_first_guess:
642
- with catchtime("Refine_first"):
643
- solution = optx.least_squares(
644
- nll,
645
- optx.BestSoFarMinimiser(optx.OptaxMinimiser(optax.adam(1e-4), 1e-6, 1e-6)),
646
- init_params,
647
- max_steps=1000,
648
- throw=False
649
- )
650
- init_params = solution.value
651
- """
652
-
653
- if solver == "bfgs":
654
- solver = optx.BestSoFarMinimiser(optx.BFGS(1e-6, 1e-6))
655
- elif solver == "levenberg_marquardt":
656
- solver = optx.BestSoFarLeastSquares(optx.LevenbergMarquardt(1e-6, 1e-6))
657
- else:
658
- raise NotImplementedError(f"{solver} is not implemented")
659
-
660
- with catchtime("Minimization"):
661
- solution = optx.least_squares(
662
- nll,
663
- solver,
664
- init_params,
665
- max_steps=num_iter_max,
666
- )
667
-
668
- params = solution.value
669
- value_flat, unflatten_fun = ravel_pytree(params)
670
-
671
- with catchtime("Compute error"):
672
- precision = jax.hessian(
673
- lambda p: jnp.sum(ravel_pytree(nll(unflatten_fun(p), None))[0] ** 2)
674
- )(value_flat)
675
-
676
- cov = Covariance.from_precision(precision)
677
-
678
- samples_flat = multivariate_normal.rvs(mean=value_flat, cov=cov, size=num_samples)
679
-
680
- samples = jax.vmap(unflatten_fun)(samples_flat)
681
- posterior_samples = jax.jit(
682
- jax.vmap(lambda p: constrain_fn(bayesian_model, tuple(), dict(observed=True), p))
683
- )(samples)
684
-
685
- with catchtime("Posterior"):
686
- posterior_predictive = Predictive(bayesian_model, posterior_samples)(
687
- keys[2], observed=False
688
- )
689
- prior = Predictive(bayesian_model, num_samples=num_samples)(keys[3], observed=False)
690
- log_likelihood = numpyro.infer.log_likelihood(bayesian_model, posterior_samples)
691
-
692
- def sanitize_chain(chain):
693
- """
694
- reshape the samples so that it is arviz compliant with an extra starting dimension
695
- """
696
- return tree_map(lambda x: x[None, ...], chain)
697
-
698
- # We export the observed values to the inference_data
699
- seeded_model = numpyro.handlers.substitute(
700
- numpyro.handlers.seed(bayesian_model, jax.random.PRNGKey(0)),
701
- substitute_fn=numpyro.infer.init_to_sample,
702
- )
703
- trace = numpyro.handlers.trace(seeded_model).get_trace()
704
- observations = {
705
- name: site["value"]
706
- for name, site in trace.items()
707
- if site["type"] == "sample" and site["is_observed"]
708
- }
709
-
710
- with catchtime("InferenceData wrapping"):
711
- inference_data = az.from_dict(
712
- sanitize_chain(posterior_samples),
713
- prior=sanitize_chain(prior),
714
- posterior_predictive=sanitize_chain(posterior_predictive),
715
- log_likelihood=sanitize_chain(log_likelihood),
716
- observed_data=observations,
717
- )
718
-
719
- inference_data = filter_inference_data(
720
- inference_data, self.observation_container, self.background_model
721
- )
722
-
723
- return FitResult(
724
- self,
725
- inference_data,
726
- self.model.params,
727
- background_model=self.background_model,
728
- )
729
-
730
-
731
- class NestedSamplingFitter(BayesianModelFitter):
603
+ class NSFitter(BayesianModelFitter):
732
604
  r"""
733
605
  A class to fit a model to a given set of observation using the Nested Sampling algorithm. This class uses the
734
606
  [`DefaultNestedSampler`][jaxns.DefaultNestedSampler] from [`jaxns`](https://jaxns.readthedocs.io/en/latest/) which
735
607
  implements the [Phantom-Powered Nested Sampling](https://arxiv.org/abs/2312.11330) algorithm.
736
- Add Citation to jaxns
608
+
609
+ !!! info
610
+ Ensure large prior volume is covered by the prior distributions to ensure the algorithm yield proper results.
611
+
737
612
  """
738
613
 
739
614
  def fit(
740
615
  self,
741
616
  rng_key: int = 0,
742
617
  num_samples: int = 1000,
618
+ num_live_points: int = 1000,
743
619
  plot_diagnostics=False,
744
620
  termination_kwargs: dict | None = None,
745
621
  verbose=True,
@@ -750,6 +626,10 @@ class NestedSamplingFitter(BayesianModelFitter):
750
626
  Parameters:
751
627
  rng_key: the random key used to initialize the sampler.
752
628
  num_samples: the number of samples to draw.
629
+ num_live_points: the number of live points to use at the start of the NS algorithm.
630
+ plot_diagnostics: whether to plot the diagnostics of the NS algorithm.
631
+ termination_kwargs: additional arguments to pass to the termination criterion of the NS algorithm.
632
+ verbose: whether to print the progress of the NS algorithm.
753
633
 
754
634
  Returns:
755
635
  A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
@@ -766,7 +646,7 @@ class NestedSamplingFitter(BayesianModelFitter):
766
646
  difficult_model=True,
767
647
  max_samples=1e6,
768
648
  parameter_estimation=True,
769
- num_live_points=1_000,
649
+ num_live_points=num_live_points,
770
650
  ),
771
651
  termination_kwargs=termination_kwargs if termination_kwargs else dict(),
772
652
  )
@@ -776,41 +656,9 @@ class NestedSamplingFitter(BayesianModelFitter):
776
656
  if plot_diagnostics:
777
657
  ns.diagnostics()
778
658
 
779
- posterior_samples = ns.get_samples(keys[1], num_samples=num_samples)
780
- log_likelihood = numpyro.infer.log_likelihood(bayesian_model, posterior_samples)
781
- posterior_predictive = Predictive(bayesian_model, posterior_samples)(
782
- keys[2], observed=False
783
- )
784
-
785
- prior = Predictive(bayesian_model, num_samples=num_samples)(keys[3], observed=False)
786
-
787
- seeded_model = numpyro.handlers.substitute(
788
- numpyro.handlers.seed(bayesian_model, jax.random.PRNGKey(0)),
789
- substitute_fn=numpyro.infer.init_to_sample,
790
- )
791
- trace = numpyro.handlers.trace(seeded_model).get_trace()
792
- observations = {
793
- name: site["value"]
794
- for name, site in trace.items()
795
- if site["type"] == "sample" and site["is_observed"]
796
- }
797
-
798
- def sanitize_chain(chain):
799
- """
800
- reshape the samples so that it is arviz compliant with an extra starting dimension
801
- """
802
- return tree_map(lambda x: x[None, ...], chain)
803
-
804
- inference_data = az.from_dict(
805
- sanitize_chain(posterior_samples),
806
- prior=sanitize_chain(prior),
807
- posterior_predictive=sanitize_chain(posterior_predictive),
808
- log_likelihood=sanitize_chain(log_likelihood),
809
- observed_data=observations,
810
- )
811
-
812
- inference_data = filter_inference_data(
813
- inference_data, self.observation_container, self.background_model
659
+ posterior = ns.get_samples(keys[1], num_samples=num_samples)
660
+ inference_data = self.build_inference_data(
661
+ posterior, num_chains=1, use_transformed_model=True
814
662
  )
815
663
 
816
664
  return FitResult(
jaxspec/model/abc.py CHANGED
@@ -7,9 +7,11 @@ import haiku as hk
7
7
  import jax
8
8
  import jax.numpy as jnp
9
9
  import networkx as nx
10
+ import rich
10
11
 
11
12
  from haiku._src import base
12
13
  from jax.scipy.integrate import trapezoid
14
+ from rich.table import Table
13
15
  from simpleeval import simple_eval
14
16
 
15
17
 
@@ -110,6 +112,30 @@ class SpectralModel:
110
112
  def params(self):
111
113
  return self.transformed_func_photon.init(None, jnp.ones(10), jnp.ones(10))
112
114
 
115
def _parameter_table(self):
    """
    Build a rich ``Table`` listing each component of the model with its parameter names.

    NOTE(review): the title uses ``str(self)``. This assumes the class defines ``__str__``; otherwise ``str``
    falls back to ``__repr__``, which renders this table and would recurse. Confirm.
    """
    table = Table(title=str(self))

    table.add_column("Component", justify="right", style="bold", no_wrap=True)
    table.add_column("Parameter")

    for component, component_params in self.params.items():
        # Only the first row of a component shows its name, so its parameters read as one group.
        for index, parameter in enumerate(component_params):
            table.add_row(component if index == 0 else "", parameter)

    return table

def __rich_repr__(self):
    """
    Return the parameter table of the model.

    NOTE(review): rich's ``__rich_repr__`` protocol expects an iterable of repr arguments, not a renderable.
    It keeps returning the table for backward compatibility. Rich renders the model through ``__rich__``.
    """
    return self._parameter_table()

def __rich__(self):
    """Render the model as its parameter table when printed with rich (e.g. ``rich.print(model)``)."""
    return self._parameter_table()

def _repr_html_(self) -> str:
    """HTML rendering of the parameter table, used by Jupyter/IPython rich display."""
    import io

    from rich.console import Console

    # Render into a recording console backed by an in-memory file. Displaying as a side effect is avoided
    # (force_jupyter=False stops rich from routing output to the notebook directly).
    console = Console(record=True, file=io.StringIO(), force_jupyter=False)
    console.print(self._parameter_table())
    return console.export_html(inline_styles=True, code_format="<pre>{code}</pre>")

def __repr__(self):
    """
    Plain-text rendering of the parameter table.

    The table is rendered into an in-memory buffer, so ``repr`` has no printing side effect and returns the
    table as a string.
    """
    import io

    from rich.console import Console

    buffer = io.StringIO()
    Console(file=buffer, force_jupyter=False).print(self._parameter_table())
    return buffer.getvalue().rstrip("\n")
138
+
113
139
  def photon_flux(self, params, e_low, e_high, n_points=2):
114
140
  r"""
115
141
  Compute the expected counts between $E_\min$ and $E_\max$ by integrating the model.
jaxspec/model/additive.py CHANGED
@@ -14,8 +14,6 @@ from haiku.initializers import Constant as HaikuConstant
14
14
 
15
15
  from ..util.integrate import integrate_interval
16
16
  from ..util.online_storage import table_manager
17
-
18
- # from ._additive.apec import APEC
19
17
  from .abc import AdditiveComponent
20
18
 
21
19
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxspec
3
- Version: 0.1.1
3
+ Version: 0.1.2
4
4
  Summary: jaxspec is a bayesian spectral fitting library for X-ray astronomy.
5
5
  Home-page: https://github.com/renecotyfanboy/jaxspec
6
6
  License: MIT
@@ -15,19 +15,18 @@ Requires-Dist: arviz (>=0.17.1,<0.20.0)
15
15
  Requires-Dist: astropy (>=6.0.0,<7.0.0)
16
16
  Requires-Dist: chainconsumer (>=1.1.2,<2.0.0)
17
17
  Requires-Dist: cmasher (>=1.6.3,<2.0.0)
18
- Requires-Dist: dm-haiku (>=0.0.11,<0.0.13)
18
+ Requires-Dist: dm-haiku (>=0.0.12,<0.0.13)
19
19
  Requires-Dist: gpjax (>=0.8.0,<0.9.0)
20
20
  Requires-Dist: interpax (>=0.3.3,<0.4.0)
21
- Requires-Dist: jax (>=0.4.30,<0.5.0)
21
+ Requires-Dist: jax (>=0.4.33,<0.5.0)
22
22
  Requires-Dist: jaxlib (>=0.4.30,<0.5.0)
23
- Requires-Dist: jaxns (>=2.5.1,<3.0.0)
23
+ Requires-Dist: jaxns (<2.6)
24
24
  Requires-Dist: jaxopt (>=0.8.1,<0.9.0)
25
25
  Requires-Dist: matplotlib (>=3.8.0,<4.0.0)
26
26
  Requires-Dist: mendeleev (>=0.15,<0.18)
27
- Requires-Dist: mkdocstrings (>=0.24,<0.27)
28
27
  Requires-Dist: networkx (>=3.1,<4.0)
29
28
  Requires-Dist: numpy (<2.0.0)
30
- Requires-Dist: numpyro (>=0.15.2,<0.16.0)
29
+ Requires-Dist: numpyro (>=0.15.3,<0.16.0)
31
30
  Requires-Dist: optimistix (>=0.0.7,<0.0.8)
32
31
  Requires-Dist: pandas (>=2.2.0,<3.0.0)
33
32
  Requires-Dist: pooch (>=1.8.2,<2.0.0)
@@ -41,7 +40,14 @@ Requires-Dist: watermark (>=2.4.3,<3.0.0)
41
40
  Project-URL: Documentation, https://jaxspec.readthedocs.io/en/latest/
42
41
  Description-Content-Type: text/markdown
43
42
 
44
- # jaxspec
43
+ <p align="center">
44
+ <img src="https://raw.githubusercontent.com/renecotyfanboy/jaxspec/main/docs/logo/logo_small.svg" alt="Logo" width="100" height="100">
45
+ </p>
46
+
47
+ <h1 align="center">
48
+ jaxspec
49
+ </h1>
50
+
45
51
 
46
52
  [![PyPI - Version](https://img.shields.io/pypi/v/jaxspec?style=for-the-badge&logo=pypi&color=rgb(37%2C%20150%2C%20190))](https://pypi.org/project/jaxspec/)
47
53
  [![Python package](https://img.shields.io/pypi/pyversions/jaxspec?style=for-the-badge)](https://pypi.org/project/jaxspec/)
@@ -1,21 +1,19 @@
1
1
  jaxspec/__init__.py,sha256=Sbn02lX6Y-zNXk17N8dec22c5jeypiS0LkHmGfz7lWA,126
2
2
  jaxspec/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
+ jaxspec/analysis/_plot.py,sha256=C4XljmuzQz8xQur_jQddgInrBDmKgTn0eugSreLoD5k,862
3
4
  jaxspec/analysis/compare.py,sha256=g2UFhmR9Zt-7cz5gQFOB6lXuklXB3yTyUvjTypOzoSY,725
4
5
  jaxspec/analysis/results.py,sha256=Kz3eryxS3N_hiajcFLTWS1dtgTQo5hlh-rDCnJ3A-3c,27811
5
6
  jaxspec/data/__init__.py,sha256=aantcYKC9kZFvaE-V2SIwSuLhIld17Kjrd9CIUu___Y,415
6
7
  jaxspec/data/grouping.py,sha256=hhgBt-voiH0DDSyePacaIGsaMnrYbJM_-ZeU66keC7I,622
7
8
  jaxspec/data/instrument.py,sha256=0pSf1p82g7syDMmKm13eVbYih-Veiq5DnwsyZe6_b4g,3890
8
- jaxspec/data/obsconf.py,sha256=0X9jR-pV-Pk4-EVuUdlVWgl_gBx8ZurVkRNrfKQWdC4,8663
9
+ jaxspec/data/obsconf.py,sha256=gv14sL6azK2avRiMCWuTbyLBPulzm4PwvoLY6iWPEVE,9833
9
10
  jaxspec/data/observation.py,sha256=1UnFu5ihZp9z-vP_I7tsFY8jhhIJunv46JyuE-acrg0,6394
10
11
  jaxspec/data/ogip.py,sha256=sv9p00qHS5pzw61pzWyyF0nV-E-RXySdSFK2tUavokA,9545
11
12
  jaxspec/data/util.py,sha256=ycLPVE-cjn6VpUWYlBU1BGfw73ANXIBilyVAUOYOSj0,9540
12
- jaxspec/fit.py,sha256=lfeqn1HFlZSFSvVbBL0obV9cCwMlbBWpJz_nw7JI0WY,29552
13
+ jaxspec/fit.py,sha256=hI0koMO4KsNpe9mLlaFm_tNLgm4BVAYVyiMb1E1eyZE,24553
13
14
  jaxspec/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
- jaxspec/model/_additive/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
15
- jaxspec/model/_additive/apec.py,sha256=r7CQqscAgR0BXC_AJqF6B7CPq3Byoo65Z-h9XgACZeU,12460
16
- jaxspec/model/_additive/apec_loaders.py,sha256=jkUoH0ezeYdaNw3oV10V0L-jt848SKp2thanLWLWp9k,2412
17
- jaxspec/model/abc.py,sha256=nQZUmtUzXjW94gv3BJg1lHXHZtgrHoOlAR4a6G2a9VQ,20234
18
- jaxspec/model/additive.py,sha256=xD5E30nd5pqa-swQireA52ch1czxnqRosnh-dsp5xL0,22485
15
+ jaxspec/model/abc.py,sha256=MuxEyvn223QPwGoFIJiST8nRMgrZ08ZLkw33oep3tx4,20887
16
+ jaxspec/model/additive.py,sha256=wjY2wL3Io3F45GJpz-UB8xYVnA-W1OFBnZMbj5pWPbQ,22449
19
17
  jaxspec/model/background.py,sha256=QSFFiuyUEvuzXBx3QfkvVneUR8KKEP-VaANEVXcavDE,7865
20
18
  jaxspec/model/list.py,sha256=0RPAoscVz_zM1CWdx_Gd5wfrQWV5Nv4Kd4bSXu2ayUA,860
21
19
  jaxspec/model/multiplicative.py,sha256=GCQ6JRz92QqbzDBFwWxGZ9SUqTJZQpD7B6ji9VEFXWo,8135
@@ -26,8 +24,8 @@ jaxspec/util/abundance.py,sha256=fsC313taIlGzQsZNwbYsJupDWm7ZbqzGhY66Ku394Mw,854
26
24
  jaxspec/util/integrate.py,sha256=_Ax_knpC7d4et2-QFkOUzVtNeQLX1-cwLvm-FRBxYcw,4505
27
25
  jaxspec/util/online_storage.py,sha256=vm56RfcbFKpkRVfr0bXO7J9aQxuBq-I_oEgA26YIhCo,2469
28
26
  jaxspec/util/typing.py,sha256=8qK1aJlsqTcVKjYN-BxsDx20BTwtnS-wMw6Bdurpm-o,2459
29
- jaxspec-0.1.1.dist-info/LICENSE.md,sha256=2q5XoWzddts5IqzIcgYYMOL21puU3MfO8gvT3Ype1eQ,1073
30
- jaxspec-0.1.1.dist-info/METADATA,sha256=vQMbkUPdTyiuLuHxEGLEitAg-2GAk65lvIPQJDfutY8,3572
31
- jaxspec-0.1.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
32
- jaxspec-0.1.1.dist-info/entry_points.txt,sha256=kzLG2mGlCWITRn4Q6zKG_idx-_RKAncvA0DMNYTgHAg,71
33
- jaxspec-0.1.1.dist-info/RECORD,,
27
+ jaxspec-0.1.2.dist-info/LICENSE.md,sha256=2q5XoWzddts5IqzIcgYYMOL21puU3MfO8gvT3Ype1eQ,1073
28
+ jaxspec-0.1.2.dist-info/METADATA,sha256=FE2bTAk-3Xryi6fplV4Y-F2eibUdLZgC9ET9_4HvdOA,3708
29
+ jaxspec-0.1.2.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
30
+ jaxspec-0.1.2.dist-info/entry_points.txt,sha256=kzLG2mGlCWITRn4Q6zKG_idx-_RKAncvA0DMNYTgHAg,71
31
+ jaxspec-0.1.2.dist-info/RECORD,,
File without changes
@@ -1,316 +0,0 @@
1
- import warnings
2
-
3
- from typing import Literal
4
-
5
- import astropy.units as u
6
- import haiku as hk
7
- import jax
8
- import jax.numpy as jnp
9
-
10
- from astropy.constants import c, m_p
11
- from haiku.initializers import Constant as HaikuConstant
12
- from jax import lax
13
- from jax.lax import fori_loop, scan
14
- from jax.scipy.stats import norm as gaussian
15
-
16
- from ...util.abundance import abundance_table, element_data
17
- from ..abc import AdditiveComponent
18
- from .apec_loaders import get_continuum, get_lines, get_pseudo, get_temperature
19
-
20
-
21
- @jax.jit
22
- def lerp(x, x0, x1, y0, y1):
23
- """
24
- Linear interpolation routine
25
- Return y(x) = (y0 * (x1 - x) + y1 * (x - x0)) / (x1 - x0)
26
- """
27
- return (y0 * (x1 - x) + y1 * (x - x0)) / (x1 - x0)
28
-
29
-
30
- @jax.jit
31
- def interp_and_integrate(energy_low, energy_high, energy_ref, continuum_ref, end_index):
32
- """
33
- This function interpolate & integrate the values of a tabulated reference continuum between two energy limits
34
- Sorry for the boilerplate here, but be sure that it works !
35
-
36
- Parameters:
37
- energy_low: lower limit of the integral
38
- energy_high: upper limit of the integral
39
- energy_ref: energy grid of the reference continuum
40
- continuum_ref: continuum values evaluated at energy_ref
41
-
42
- """
43
- energy_ref = jnp.where(jnp.arange(energy_ref.shape[0]) < end_index, energy_ref, jnp.nan)
44
- start_index = jnp.searchsorted(energy_ref, energy_low, side="left") - 1
45
- end_index = jnp.searchsorted(energy_ref, energy_high, side="left") + 1
46
-
47
- def body_func(index, value):
48
- integrated_flux, previous_energy, previous_continuum = value
49
- current_energy, current_continuum = energy_ref[index], continuum_ref[index]
50
-
51
- # 5 cases
52
- # Neither current and previous energies are within the integral limits > nothing is added to the integrated flux
53
- # The left limit of the integral is between the current and previous energy > previous energy is set to the limit, previous continuum is interpolated, and then added to the integrated flux
54
- # The right limit of the integral is between the current and previous energy > current energy is set to the limit, current continuum is interpolated, and then added to the integrated flux
55
- # Both current and previous energies are within the integral limits -> add to the integrated flux
56
- # Within
57
-
58
- current_energy_is_between = (energy_low <= current_energy) * (current_energy < energy_high)
59
- previous_energy_is_between = (energy_low <= previous_energy) * (
60
- previous_energy < energy_high
61
- )
62
- energies_within_bins = (previous_energy <= energy_low) * (energy_high < current_energy)
63
-
64
- case = (
65
- (1 - previous_energy_is_between) * current_energy_is_between * 1
66
- + previous_energy_is_between * (1 - current_energy_is_between) * 2
67
- + (previous_energy_is_between * current_energy_is_between) * 3
68
- + energies_within_bins * 4
69
- )
70
-
71
- term_to_add = lax.switch(
72
- case,
73
- [
74
- lambda pe, pc, ce, cc, el, er: 0.0, # 1
75
- lambda pe, pc, ce, cc, el, er: (cc + lerp(el, pe, ce, pc, cc)) * (ce - el) / 2, # 2
76
- lambda pe, pc, ce, cc, el, er: (pc + lerp(er, pe, ce, pc, cc)) * (er - pe) / 2, # 3
77
- lambda pe, pc, ce, cc, el, er: (pc + cc) * (ce - pe) / 2, # 4
78
- lambda pe, pc, ce, cc, el, er: (lerp(el, pe, ce, pc, cc) + lerp(er, pe, ce, pc, cc))
79
- * (er - el)
80
- / 2,
81
- # 5
82
- ],
83
- previous_energy,
84
- previous_continuum,
85
- current_energy,
86
- current_continuum,
87
- energy_low,
88
- energy_high,
89
- )
90
-
91
- return integrated_flux + term_to_add, current_energy, current_continuum
92
-
93
- integrated_flux, _, _ = fori_loop(start_index, end_index, body_func, (0.0, 0.0, 0.0))
94
-
95
- return integrated_flux
96
-
97
-
98
- @jax.jit
99
- def interp(e_low, e_high, energy_ref, continuum_ref, end_index):
100
- energy_ref = jnp.where(jnp.arange(energy_ref.shape[0]) < end_index, energy_ref, jnp.nan)
101
-
102
- return (
103
- jnp.interp(e_high, energy_ref, continuum_ref) - jnp.interp(e_low, energy_ref, continuum_ref)
104
- ) / (e_high - e_low)
105
-
106
-
107
- @jax.jit
108
- def interp_flux(energy, energy_ref, continuum_ref, end_index):
109
- """
110
- Iterate through an array of shape (energy_ref,) and compute the flux between the bins defined by energy
111
- """
112
-
113
- def scanned_func(carry, unpack):
114
- e_low, e_high = unpack
115
- continuum = interp_and_integrate(e_low, e_high, energy_ref, continuum_ref, end_index)
116
-
117
- return carry, continuum
118
-
119
- _, continuum = scan(scanned_func, 0.0, (energy[:-1], energy[1:]))
120
-
121
- return continuum
122
-
123
-
124
- @jax.jit
125
- def interp_flux_elements(energy_ref, continuum_ref, end_index, energy, abundances):
126
- """
127
- Iterate through an array of shape (abundance, energy_ref) and compute the flux between the bins defined by energy
128
- and weight the flux depending on the abundance of each element
129
- """
130
-
131
- def scanned_func(_, unpack):
132
- energy_ref, continuum_ref, end_idx = unpack
133
- element_flux = interp_flux(energy, energy_ref, continuum_ref, end_idx)
134
-
135
- return _, element_flux
136
-
137
- _, flux = scan(scanned_func, 0.0, (energy_ref, continuum_ref, end_index))
138
-
139
- return abundances @ flux
140
-
141
-
142
- @jax.jit
143
- def get_lines_contribution_broadening(
144
- line_energy, line_element, line_emissivity, end_index, energy, abundances, total_broadening
145
- ):
146
- def body_func(i, flux):
147
- # Notice the -1 in line element to match the 0-based indexing
148
- l_energy, l_emissivity, l_element = line_energy[i], line_emissivity[i], line_element[i] - 1
149
- broadening = l_energy * total_broadening[l_element]
150
- l_flux = gaussian.cdf(energy[1:], l_energy, broadening) - gaussian.cdf(
151
- energy[:-1], l_energy, broadening
152
- )
153
- l_flux = l_flux * l_emissivity * abundances[l_element]
154
-
155
- return flux + l_flux
156
-
157
- return fori_loop(0, end_index, body_func, jnp.zeros_like(energy[:-1]))
158
-
159
-
160
- @jax.jit
161
- def continuum_func(energy, kT, abundances):
162
- idx, kT_low, kT_high = get_temperature(kT)
163
- continuum_low = interp_flux_elements(*get_continuum(idx), energy, abundances)
164
- continuum_high = interp_flux_elements(*get_continuum(idx + 1), energy, abundances)
165
-
166
- return lerp(kT, kT_low, kT_high, continuum_low, continuum_high)
167
-
168
-
169
- @jax.jit
170
- def pseudo_func(energy, kT, abundances):
171
- idx, kT_low, kT_high = get_temperature(kT)
172
- continuum_low = interp_flux_elements(*get_pseudo(idx), energy, abundances)
173
- continuum_high = interp_flux_elements(*get_pseudo(idx + 1), energy, abundances)
174
-
175
- return lerp(kT, kT_low, kT_high, continuum_low, continuum_high)
176
-
177
-
178
- # @jax.custom_jvp
179
- @jax.jit
180
- def lines_func(energy, kT, abundances, broadening):
181
- idx, kT_low, kT_high = get_temperature(kT)
182
- line_low = get_lines_contribution_broadening(*get_lines(idx), energy, abundances, broadening)
183
- line_high = get_lines_contribution_broadening(
184
- *get_lines(idx + 1), energy, abundances, broadening
185
- )
186
-
187
- return lerp(kT, kT_low, kT_high, line_low, line_high)
188
-
189
-
190
- class APEC(AdditiveComponent):
191
- """
192
- APEC model implementation in pure JAX for X-ray spectral fitting.
193
-
194
- !!! warning
195
- This implementation is optimised for the CPU, it shows poor performance on the GPU.
196
- """
197
-
198
- def __init__(
199
- self,
200
- continuum: bool = True,
201
- pseudo: bool = True,
202
- lines: bool = True,
203
- thermal_broadening: bool = True,
204
- turbulent_broadening: bool = True,
205
- variant: Literal["none", "v", "vv"] = "none",
206
- abundance_table: Literal[
207
- "angr", "aspl", "feld", "aneb", "grsa", "wilm", "lodd", "lgpp", "lgps"
208
- ] = "angr",
209
- trace_abundance: float = 1.0,
210
- **kwargs,
211
- ):
212
- super().__init__(**kwargs)
213
-
214
- warnings.warn("Be aware that this APEC implementation is not meant to be used yet")
215
-
216
- self.atomic_weights = jnp.asarray(element_data["atomic_weight"].to_numpy())
217
-
218
- self.abundance_table = abundance_table
219
- self.thermal_broadening = thermal_broadening
220
- self.turbulent_broadening = turbulent_broadening
221
- self.continuum_to_compute = continuum
222
- self.pseudo_to_compute = pseudo
223
- self.lines_to_compute = lines
224
- self.trace_abundance = trace_abundance
225
- self.variant = variant
226
-
227
- def get_thermal_broadening(self):
228
- r"""
229
- Compute the thermal broadening $\sigma_T$ for each element using :
230
-
231
- $$ \frac{\sigma_T}{E_{\text{line}}} = \frac{1}{c}\sqrt{\frac{k_{B} T}{A m_p}}$$
232
-
233
- where $E_{\text{line}}$ is the energy of the line, $c$ is the speed of light, $k_{B}$ is the Boltzmann constant,
234
- $T$ is the temperature, $A$ is the atomic weight of the element and $m_p$ is the proton mass.
235
- """
236
-
237
- if self.thermal_broadening:
238
- kT = hk.get_parameter("kT", [], init=HaikuConstant(6.5))
239
- factor = 1 / c * (1 / m_p) ** (1 / 2)
240
- factor = factor.to(u.keV ** (-1 / 2)).value
241
-
242
- # Multiply this factor by Line_Energy * sqrt(kT/A) to get the broadening for a line
243
- # This return value must be multiplied by the energy of the line to get actual broadening
244
- return factor * jnp.sqrt(kT / self.atomic_weights)
245
-
246
- else:
247
- return jnp.zeros((30,))
248
-
249
- def get_turbulent_broadening(self):
250
- r"""
251
- Return the turbulent broadening using :
252
-
253
- $$\frac{\sigma_\text{turb}}{E_{\text{line}}} = \frac{\sigma_{v ~ ||}}{c}$$
254
-
255
- where $\sigma_{v ~ ||}$ is the velocity dispersion along the line of sight in km/s.
256
- """
257
- if self.turbulent_broadening:
258
- # This return value must be multiplied by the energy of the line to get actual broadening
259
- return (
260
- hk.get_parameter("Velocity", [], init=HaikuConstant(100.0)) / c.to(u.km / u.s).value
261
- )
262
- else:
263
- return 0.0
264
-
265
- def get_parameters(self):
266
- none_elements = ["C", "N", "O", "Ne", "Mg", "Al", "Si", "S", "Ar", "Ca", "Fe", "Ni"]
267
- v_elements = ["He", "C", "N", "O", "Ne", "Mg", "Al", "Si", "S", "Ar", "Ca", "Fe", "Ni"]
268
- trace_elements = (
269
- jnp.asarray([3, 4, 5, 9, 11, 15, 17, 19, 21, 22, 23, 24, 25, 27, 29, 30], dtype=int) - 1
270
- )
271
-
272
- # Set abundances of trace element (will be overwritten in the vv case)
273
- abund = jnp.ones((30,)).at[trace_elements].multiply(self.trace_abundance)
274
-
275
- if self.variant == "vv":
276
- for i, element in enumerate(abundance_table["Element"]):
277
- if element != "H":
278
- abund = abund.at[i].set(hk.get_parameter(element, [], init=HaikuConstant(1.0)))
279
-
280
- elif self.variant == "v":
281
- for i, element in enumerate(abundance_table["Element"]):
282
- if element != "H" and element in v_elements:
283
- abund = abund.at[i].set(hk.get_parameter(element, [], init=HaikuConstant(1.0)))
284
-
285
- else:
286
- Z = hk.get_parameter("Abundance", [], init=HaikuConstant(1.0))
287
- for i, element in enumerate(abundance_table["Element"]):
288
- if element != "H" and element in none_elements:
289
- abund = abund.at[i].set(Z)
290
-
291
- if abund != "angr":
292
- abund = abund * jnp.asarray(
293
- abundance_table[self.abundance_table] / abundance_table["angr"]
294
- )
295
-
296
- # Set the temperature, redshift, normalisation
297
- kT = hk.get_parameter("kT", [], init=HaikuConstant(6.5))
298
- z = hk.get_parameter("Redshift", [], init=HaikuConstant(0.0))
299
- norm = hk.get_parameter("norm", [], init=HaikuConstant(1.0))
300
-
301
- return kT, z, norm, abund
302
-
303
- def emission_lines(self, e_low, e_high):
304
- # Get the parameters and extract the relevant data
305
- energy = jnp.hstack([e_low, e_high[-1]])
306
- kT, z, norm, abundances = self.get_parameters()
307
- total_broadening = jnp.hypot(self.get_thermal_broadening(), self.get_turbulent_broadening())
308
- energy = energy * (1 + z)
309
-
310
- continuum = continuum_func(energy, kT, abundances) if self.continuum_to_compute else 0.0
311
- pseudo_continuum = pseudo_func(energy, kT, abundances) if self.pseudo_to_compute else 0.0
312
- lines = (
313
- lines_func(energy, kT, abundances, total_broadening) if self.lines_to_compute else 0.0
314
- )
315
-
316
- return (continuum + pseudo_continuum + lines) * norm * 1e14 / (1 + z), (e_low + e_high) / 2
@@ -1,73 +0,0 @@
1
- """This module contains the functions that load the APEC tables from the HDF5 file. They are implemented as JAX
2
- pure callback to enable reading data from the files without saturating the memory."""
3
-
4
- import h5netcdf
5
- import jax
6
- import jax.numpy as jnp
7
-
8
- from ...util.online_storage import table_manager
9
-
10
-
11
- @jax.jit
12
- def temperature_table_getter():
13
- with h5netcdf.File(table_manager.fetch("apec.nc"), "r") as f:
14
- temperature = jnp.asarray(f["/temperature"])
15
-
16
- return temperature
17
-
18
-
19
- @jax.jit
20
- def get_temperature(kT):
21
- temperature = temperature_table_getter()
22
- idx = jnp.searchsorted(temperature, kT) - 1
23
-
24
- return idx, temperature[idx], temperature[idx + 1]
25
-
26
-
27
- @jax.jit
28
- def continuum_table_getter():
29
- with h5netcdf.File(table_manager.fetch("apec.nc"), "r") as f:
30
- continuum_energy = jnp.asarray(f["/continuum_energy"])
31
- continuum_emissivity = jnp.asarray(f["/continuum_emissivity"])
32
- continuum_end_index = jnp.asarray(f["/continuum_end_index"])
33
-
34
- return continuum_energy, continuum_emissivity, continuum_end_index
35
-
36
-
37
- @jax.jit
38
- def pseudo_table_getter():
39
- with h5netcdf.File(table_manager.fetch("apec.nc"), "r") as f:
40
- pseudo_energy = jnp.asarray(f["/pseudo_energy"])
41
- pseudo_emissivity = jnp.asarray(f["/pseudo_emissivity"])
42
- pseudo_end_index = jnp.asarray(f["/pseudo_end_index"])
43
-
44
- return pseudo_energy, pseudo_emissivity, pseudo_end_index
45
-
46
-
47
- @jax.jit
48
- def line_table_getter():
49
- with h5netcdf.File(table_manager.fetch("apec.nc"), "r") as f:
50
- line_energy = jnp.asarray(f["/line_energy"])
51
- line_element = jnp.asarray(f["/line_element"])
52
- line_emissivity = jnp.asarray(f["/line_emissivity"])
53
- line_end_index = jnp.asarray(f["/line_end_index"])
54
-
55
- return line_energy, line_element, line_emissivity, line_end_index
56
-
57
-
58
- @jax.jit
59
- def get_continuum(idx):
60
- continuum_energy, continuum_emissivity, continuum_end_index = continuum_table_getter()
61
- return continuum_energy[idx], continuum_emissivity[idx], continuum_end_index[idx]
62
-
63
-
64
- @jax.jit
65
- def get_pseudo(idx):
66
- pseudo_energy, pseudo_emissivity, pseudo_end_index = pseudo_table_getter()
67
- return pseudo_energy[idx], pseudo_emissivity[idx], pseudo_end_index[idx]
68
-
69
-
70
- @jax.jit
71
- def get_lines(idx):
72
- line_energy, line_element, line_emissivity, line_end_index = line_table_getter()
73
- return line_energy[idx], line_element[idx], line_emissivity[idx], line_end_index[idx]