jaxspec 0.1.4__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,7 +1,10 @@
1
1
  import os
2
+
2
3
  import numpy as np
3
4
  import xarray as xr
5
+
4
6
  from matplotlib import colors
7
+
5
8
  from .ogip import DataARF, DataRMF
6
9
 
7
10
 
@@ -25,7 +28,15 @@ class Instrument(xr.Dataset):
25
28
  )
26
29
 
27
30
  @classmethod
28
- def from_matrix(cls, redistribution_matrix, spectral_response, e_min_unfolded, e_max_unfolded, e_min_channel, e_max_channel):
31
+ def from_matrix(
32
+ cls,
33
+ redistribution_matrix,
34
+ spectral_response,
35
+ e_min_unfolded,
36
+ e_max_unfolded,
37
+ e_min_channel,
38
+ e_max_channel,
39
+ ):
29
40
  return cls(
30
41
  {
31
42
  "redistribution": (
@@ -65,7 +76,7 @@ class Instrument(xr.Dataset):
65
76
  )
66
77
 
67
78
  @classmethod
68
- def from_ogip_file(cls, rmf_path: str | os.PathLike, arf_path: str | os.PathLike = None):
79
+ def from_ogip_file(cls, rmf_path: str | os.PathLike, arf_path: str | os.PathLike | None = None):
69
80
  """
70
81
  Load the data from OGIP files.
71
82
 
@@ -80,37 +91,61 @@ class Instrument(xr.Dataset):
80
91
  specresp = DataARF.from_file(arf_path).specresp
81
92
 
82
93
  else:
83
- specresp = np.ones(rmf.energ_lo.shape)
94
+ specresp = rmf.matrix.sum(axis=0)
95
+ rmf.sparse_matrix /= specresp
84
96
 
85
- return cls.from_matrix(rmf.sparse_matrix, specresp, rmf.energ_lo, rmf.energ_hi, rmf.e_min, rmf.e_max)
97
+ return cls.from_matrix(
98
+ rmf.sparse_matrix, specresp, rmf.energ_lo, rmf.energ_hi, rmf.e_min, rmf.e_max
99
+ )
86
100
 
87
- def plot_redistribution(self, **kwargs):
101
+ def plot_redistribution(
102
+ self,
103
+ xscale: str = "log",
104
+ yscale: str = "log",
105
+ cmap=None,
106
+ vmin: float = 1e-6,
107
+ vmax: float = 1e0,
108
+ add_labels: bool = True,
109
+ **kwargs,
110
+ ):
88
111
  """
89
112
  Plot the redistribution probability matrix
90
113
 
91
114
  Parameters:
115
+ xscale : The scale of the x-axis.
116
+ yscale : The scale of the y-axis.
117
+ cmap : The colormap to use.
118
+ vmin : The minimum value for the colormap.
119
+ vmax : The maximum value for the colormap.
120
+ add_labels : Whether to add labels to the plot.
92
121
  **kwargs : `kwargs` passed to https://docs.xarray.dev/en/latest/generated/xarray.plot.pcolormesh.html#xarray.plot.pcolormesh
93
122
  """
123
+
94
124
  import cmasher as cmr
95
125
 
96
126
  return xr.plot.pcolormesh(
97
127
  self.redistribution,
98
128
  x="e_max_unfolded",
99
129
  y="e_max_channel",
100
- xscale="log",
101
- yscale="log",
102
- cmap=cmr.ember_r,
103
- norm=colors.LogNorm(vmin=1e-6, vmax=1),
104
- add_labels=True,
130
+ xscale=xscale,
131
+ yscale=yscale,
132
+ cmap=cmr.ember_r if cmap is None else cmap,
133
+ norm=colors.LogNorm(vmin=vmin, vmax=vmax),
134
+ add_labels=add_labels,
105
135
  **kwargs,
106
136
  )
107
137
 
108
- def plot_area(self, **kwargs):
138
+ def plot_area(self, xscale: str = "log", yscale: str = "log", where: str = "post", **kwargs):
109
139
  """
110
140
  Plot the effective area
111
141
 
112
142
  Parameters:
143
+ xscale : The scale of the x-axis.
144
+ yscale : The scale of the y-axis.
145
+ where : The position of the steps.
113
146
  **kwargs : `kwargs` passed to https://docs.xarray.dev/en/latest/generated/xarray.DataArray.plot.line.html#xarray.DataArray.plot.line
114
147
  """
115
148
 
116
- return self.area.plot.step(x="e_min_unfolded", xscale="log", yscale="log", where="post", **kwargs)
149
+ return self.area.plot.step(
150
+ x="e_min_unfolded", xscale=xscale, yscale=yscale, where=where, **kwargs
151
+ )
jaxspec/data/obsconf.py CHANGED
@@ -157,8 +157,10 @@ class ObsConfiguration(xr.Dataset):
157
157
 
158
158
  if observation.folded_background is not None:
159
159
  folded_background = observation.folded_background.data[row_idx]
160
+ folded_background_unscaled = observation.folded_background_unscaled.data[row_idx]
160
161
  else:
161
162
  folded_background = np.zeros_like(folded_counts)
163
+ folded_background_unscaled = np.zeros_like(folded_counts)
162
164
 
163
165
  data_dict = {
164
166
  "transfer_matrix": (
@@ -200,6 +202,14 @@ class ObsConfiguration(xr.Dataset):
200
202
  "unit": "photons",
201
203
  },
202
204
  ),
205
+ "folded_background_unscaled": (
206
+ ["folded_channel"],
207
+ folded_background_unscaled,
208
+ {
209
+ "description": "To be done",
210
+ "unit": "photons",
211
+ },
212
+ ),
203
213
  }
204
214
 
205
215
  return cls(
@@ -234,8 +244,8 @@ class ObsConfiguration(xr.Dataset):
234
244
  cls,
235
245
  instrument: Instrument,
236
246
  exposure: float,
237
- low_energy: float = 1e-20,
238
- high_energy: float = 1e20,
247
+ low_energy: float = 1e-300,
248
+ high_energy: float = 1e300,
239
249
  ):
240
250
  """
241
251
  Create a mock observation configuration from an instrument object. The fake observation will have zero counts.
@@ -46,14 +46,16 @@ class Observation(xr.Dataset):
46
46
  quality,
47
47
  exposure,
48
48
  background=None,
49
+ background_unscaled=None,
49
50
  backratio=1.0,
50
51
  attributes: dict | None = None,
51
52
  ):
52
53
  if attributes is None:
53
54
  attributes = {}
54
55
 
55
- if background is None:
56
+ if background is None or background_unscaled is None:
56
57
  background = np.zeros_like(counts, dtype=np.int64)
58
+ background_unscaled = np.zeros_like(counts, dtype=np.int64)
57
59
 
58
60
  data_dict = {
59
61
  "counts": (
@@ -92,9 +94,14 @@ class Observation(xr.Dataset):
92
94
  np.asarray(background, dtype=np.int64),
93
95
  {"description": "Background counts", "unit": "photons"},
94
96
  ),
97
+ "folded_background_unscaled": (
98
+ ["folded_channel"],
99
+ np.asarray(np.ma.filled(grouping @ background_unscaled), dtype=np.int64),
100
+ {"description": "Background counts", "unit": "photons"},
101
+ ),
95
102
  "folded_background": (
96
103
  ["folded_channel"],
97
- np.asarray(np.ma.filled(grouping @ background), dtype=np.int64),
104
+ np.asarray(np.ma.filled(grouping @ background), dtype=np.float64),
98
105
  {"description": "Background counts", "unit": "photons"},
99
106
  ),
100
107
  }
@@ -128,14 +135,20 @@ class Observation(xr.Dataset):
128
135
  else:
129
136
  backratio = np.ones_like(pha.counts)
130
137
 
138
+ if (bkg is not None) and ("NET" in pha.flags):
139
+ counts = pha.counts + bkg.counts * backratio
140
+ else:
141
+ counts = pha.counts
142
+
131
143
  return cls.from_matrix(
132
- pha.counts,
144
+ counts,
133
145
  pha.grouping,
134
146
  pha.channel,
135
147
  pha.quality,
136
148
  pha.exposure,
137
149
  backratio=backratio,
138
- background=bkg.counts if bkg is not None else None,
150
+ background=bkg.counts * backratio if bkg is not None else None,
151
+ background_unscaled=bkg.counts if bkg is not None else None,
139
152
  attributes=metadata,
140
153
  )
141
154
 
jaxspec/data/ogip.py CHANGED
@@ -1,9 +1,11 @@
1
- import numpy as np
2
1
  import os
2
+
3
3
  import astropy.units as u
4
+ import numpy as np
4
5
  import sparse
5
- from astropy.table import QTable
6
+
6
7
  from astropy.io import fits
8
+ from astropy.table import QTable
7
9
 
8
10
 
9
11
  class DataPHA:
@@ -25,6 +27,7 @@ class DataPHA:
25
27
  ancrfile=None,
26
28
  backscal=1.0,
27
29
  areascal=1.0,
30
+ flags=None,
28
31
  ):
29
32
  self.channel = np.asarray(channel, dtype=int)
30
33
  self.counts = np.asarray(counts, dtype=int)
@@ -36,6 +39,7 @@ class DataPHA:
36
39
  self.ancrfile = ancrfile
37
40
  self.backscal = np.asarray(backscal, dtype=float)
38
41
  self.areascal = np.asarray(areascal, dtype=float)
42
+ self.flags = flags
39
43
 
40
44
  if grouping is not None:
41
45
  # Indices array of the beginning of each group
@@ -55,7 +59,9 @@ class DataPHA:
55
59
  data.append(True)
56
60
 
57
61
  # Create a COO sparse matrix
58
- grp_matrix = sparse.COO((data, (rows, cols)), shape=(len(b_grp), len(channel)), fill_value=0)
62
+ grp_matrix = sparse.COO(
63
+ (data, (rows, cols)), shape=(len(b_grp), len(channel)), fill_value=0
64
+ )
59
65
 
60
66
  else:
61
67
  # Identity matrix case, use sparse for efficiency
@@ -74,12 +80,7 @@ class DataPHA:
74
80
 
75
81
  data = QTable.read(pha_file, "SPECTRUM")
76
82
  header = fits.getheader(pha_file, "SPECTRUM")
77
-
78
- if header.get("HDUCLAS2") == "NET":
79
- raise ValueError(
80
- f"The HDUCLAS2={header.get('HDUCLAS2')} keyword in the PHA file is not supported."
81
- f"Please open an issue if this is required."
82
- )
83
+ flags = []
83
84
 
84
85
  if header.get("HDUCLAS3") == "RATE":
85
86
  raise ValueError(
@@ -114,6 +115,8 @@ class DataPHA:
114
115
  else:
115
116
  raise ValueError("No BACKSCAL found in the PHA file.")
116
117
 
118
+ backscal = np.where(backscal == 0, 1.0, backscal)
119
+
117
120
  if "AREASCAL" in header:
118
121
  areascal = header["AREASCAL"]
119
122
  elif "AREASCAL" in data.colnames:
@@ -121,8 +124,9 @@ class DataPHA:
121
124
  else:
122
125
  raise ValueError("No AREASCAL found in the PHA file.")
123
126
 
124
- # Grouping and quality parameters are in binned PHA dataset
125
- # Backfile, respfile and ancrfile are in primary header
127
+ if header.get("HDUCLAS2") == "NET":
128
+ flags.append("NET")
129
+
126
130
  kwargs = {
127
131
  "grouping": grouping,
128
132
  "quality": quality,
@@ -131,6 +135,7 @@ class DataPHA:
131
135
  "ancrfile": header.get("ANCRFILE"),
132
136
  "backscal": backscal,
133
137
  "areascal": areascal,
138
+ "flags": flags,
134
139
  }
135
140
 
136
141
  return cls(data["CHANNEL"], data["COUNTS"], header["EXPOSURE"], **kwargs)
@@ -176,7 +181,19 @@ class DataRMF:
176
181
  * [The Calibration Requirements for Spectral Analysis Addendum: Changes log](https://heasarc.gsfc.nasa.gov/docs/heasarc/caldb/docs/memos/cal_gen_92_002a/cal_gen_92_002a.html)
177
182
  """
178
183
 
179
- def __init__(self, energ_lo, energ_hi, n_grp, f_chan, n_chan, matrix, channel, e_min, e_max, low_threshold=0.0):
184
+ def __init__(
185
+ self,
186
+ energ_lo,
187
+ energ_hi,
188
+ n_grp,
189
+ f_chan,
190
+ n_chan,
191
+ matrix,
192
+ channel,
193
+ e_min,
194
+ e_max,
195
+ low_threshold=0.0,
196
+ ):
180
197
  # RMF stuff
181
198
  self.energ_lo = energ_lo # "Entry" energies
182
199
  self.energ_hi = energ_hi # "Entry" energies
@@ -229,7 +246,9 @@ class DataRMF:
229
246
  idxs = data > low_threshold
230
247
 
231
248
  # Create a COO sparse matrix and then convert to CSR for efficiency
232
- coo = sparse.COO([rows[idxs], cols[idxs]], data[idxs], shape=(len(self.energ_lo), len(self.channel)))
249
+ coo = sparse.COO(
250
+ [rows[idxs], cols[idxs]], data[idxs], shape=(len(self.energ_lo), len(self.channel))
251
+ )
233
252
  self.sparse_matrix = coo.T # .tocsr()
234
253
 
235
254
  @property
jaxspec/data/util.py CHANGED
@@ -2,16 +2,13 @@ from collections.abc import Mapping
2
2
  from pathlib import Path
3
3
  from typing import Literal, TypeVar
4
4
 
5
- import haiku as hk
6
5
  import jax
7
- import numpy as np
8
6
  import numpyro
9
7
 
10
8
  from astropy.io import fits
11
- from numpy.typing import ArrayLike
12
9
  from numpyro import handlers
13
10
 
14
- from .._fit._build_model import CountForwardModel
11
+ from .._fit._build_model import forward_model
15
12
  from ..model.abc import SpectralModel
16
13
  from ..util.online_storage import table_manager
17
14
  from . import Instrument, ObsConfiguration, Observation
@@ -127,67 +124,6 @@ def load_example_obsconf(source: Literal["NGC7793_ULX4_PN", "NGC7793_ULX4_ALL"])
127
124
  raise ValueError(f"{source} not recognized.")
128
125
 
129
126
 
130
- def fakeit(
131
- instrument: ObsConfiguration | list[ObsConfiguration],
132
- model: SpectralModel,
133
- parameters: Mapping[K, V],
134
- rng_key: int = 0,
135
- sparsify_matrix: bool = False,
136
- ) -> ArrayLike | list[ArrayLike]:
137
- """
138
- Convenience function to simulate a spectrum from a given model and a set of parameters.
139
- It requires an instrumental setup, and unlike in
140
- [XSPEC's fakeit](https://heasarc.gsfc.nasa.gov/xanadu/xspec/manual/node72.html), the error on the counts is given
141
- exclusively by Poisson statistics.
142
-
143
- Parameters:
144
- instrument: The instrumental setup.
145
- model: The model to use.
146
- parameters: The parameters of the model.
147
- rng_key: The random number generator seed.
148
- sparsify_matrix: Whether to sparsify the matrix or not.
149
- """
150
-
151
- instruments = [instrument] if isinstance(instrument, ObsConfiguration) else instrument
152
- fakeits = []
153
-
154
- for i, instrument in enumerate(instruments):
155
- transformed_model = hk.without_apply_rng(
156
- hk.transform(
157
- lambda par: CountForwardModel(model, instrument, sparse=sparsify_matrix)(par)
158
- )
159
- )
160
-
161
- def obs_model(p):
162
- return transformed_model.apply(None, p)
163
-
164
- with handlers.seed(rng_seed=rng_key):
165
- counts = numpyro.sample(
166
- f"likelihood_obs_{i}",
167
- numpyro.distributions.Poisson(obs_model(parameters)),
168
- )
169
-
170
- """
171
- pha = DataPHA(
172
- instrument.rmf.channel,
173
- np.array(counts, dtype=int)*u.ct,
174
- instrument.exposure,
175
- grouping=instrument.grouping)
176
-
177
- observation = Observation(
178
- pha=pha,
179
- arf=instrument.arf,
180
- rmf=instrument.rmf,
181
- low_energy=instrument.low_energy,
182
- high_energy=instrument.high_energy
183
- )
184
- """
185
-
186
- fakeits.append(np.array(counts, dtype=int))
187
-
188
- return fakeits[0] if len(fakeits) == 1 else fakeits
189
-
190
-
191
127
  def fakeit_for_multiple_parameters(
192
128
  instrument: ObsConfiguration | list[ObsConfiguration],
193
129
  model: SpectralModel,
@@ -199,7 +135,6 @@ def fakeit_for_multiple_parameters(
199
135
  """
200
136
  Convenience function to simulate multiple spectra from a given model and a set of parameters.
201
137
 
202
- TODO : avoid redundancy, better doc and type hints
203
138
 
204
139
  Parameters:
205
140
  instrument: The instrumental setup.
@@ -214,24 +149,19 @@ def fakeit_for_multiple_parameters(
214
149
  fakeits = []
215
150
 
216
151
  for i, obs in enumerate(instruments):
217
- transformed_model = hk.without_apply_rng(
218
- hk.transform(lambda par: CountForwardModel(model, obs, sparse=sparsify_matrix)(par))
152
+ countrate = jax.vmap(lambda p: forward_model(model, p, instrument, sparse=sparsify_matrix))(
153
+ parameters
219
154
  )
220
155
 
221
- @jax.jit
222
- @jax.vmap
223
- def obs_model(p):
224
- return transformed_model.apply(None, p)
225
-
226
156
  if apply_stat:
227
157
  with handlers.seed(rng_seed=rng_key):
228
158
  spectrum = numpyro.sample(
229
159
  f"likelihood_obs_{i}",
230
- numpyro.distributions.Poisson(obs_model(parameters)),
160
+ numpyro.distributions.Poisson(countrate),
231
161
  )
232
162
 
233
163
  else:
234
- spectrum = obs_model(parameters)
164
+ spectrum = countrate
235
165
 
236
166
  fakeits.append(spectrum)
237
167
 
jaxspec/fit.py CHANGED
@@ -10,7 +10,6 @@ import arviz as az
10
10
  import jax
11
11
  import jax.numpy as jnp
12
12
  import matplotlib.pyplot as plt
13
- import numpy as np
14
13
  import numpyro
15
14
 
16
15
  from jax import random
@@ -23,12 +22,16 @@ from numpyro.infer.reparam import TransformReparam
23
22
  from numpyro.infer.util import log_density
24
23
 
25
24
  from ._fit._build_model import build_prior, forward_model
26
- from .analysis._plot import _plot_poisson_data_with_error
25
+ from .analysis._plot import (
26
+ _error_bars_for_observed_data,
27
+ _plot_binned_samples_with_error,
28
+ _plot_poisson_data_with_error,
29
+ )
27
30
  from .analysis.results import FitResult
28
31
  from .data import ObsConfiguration
29
32
  from .model.abc import SpectralModel
30
33
  from .model.background import BackgroundModel
31
- from .util.typing import PriorDictModel, PriorDictType
34
+ from .util.typing import PriorDictType
32
35
 
33
36
 
34
37
  class BayesianModel:
@@ -63,10 +66,12 @@ class BayesianModel:
63
66
 
64
67
  if not callable(prior_distributions):
65
68
  # Validate the entry with pydantic
66
- prior = PriorDictModel.from_dict(prior_distributions).nested_dict
69
+ # prior = PriorDictModel.from_dict(prior_distributions).
67
70
 
68
71
  def prior_distributions_func():
69
- return build_prior(prior, expand_shape=(len(self.observation_container),))
72
+ return build_prior(
73
+ prior_distributions, expand_shape=(len(self.observation_container),)
74
+ )
70
75
 
71
76
  else:
72
77
  prior_distributions_func = prior_distributions
@@ -74,6 +79,22 @@ class BayesianModel:
74
79
  self.prior_distributions_func = prior_distributions_func
75
80
  self.init_params = self.prior_samples()
76
81
 
82
+ # Check the priors are suited for the observations
83
+ split_parameters = [
84
+ (param, shape[-1])
85
+ for param, shape in jax.tree.map(lambda x: x.shape, self.init_params).items()
86
+ if (len(shape) > 1)
87
+ and not param.startswith("_")
88
+ and not param.startswith("bkg") # hardcoded for subtracted background
89
+ ]
90
+
91
+ for parameter, proposed_number_of_obs in split_parameters:
92
+ if proposed_number_of_obs != len(self.observation_container):
93
+ raise ValueError(
94
+ f"Invalid splitting in the prior distribution. "
95
+ f"Expected {len(self.observation_container)} but got {proposed_number_of_obs} for {parameter}"
96
+ )
97
+
77
98
  @cached_property
78
99
  def observation_container(self) -> dict[str, ObsConfiguration]:
79
100
  """
@@ -137,7 +158,9 @@ class BayesianModel:
137
158
  with numpyro.plate("obs_plate_" + name, len(observation.folded_counts)):
138
159
  numpyro.sample(
139
160
  "obs_" + name,
140
- Poisson(obs_countrate + bkg_countrate / observation.folded_backratio.data),
161
+ Poisson(
162
+ obs_countrate + bkg_countrate
163
+ ), # / observation.folded_backratio.data
141
164
  obs=observation.folded_counts.data if observed else None,
142
165
  )
143
166
 
@@ -293,37 +316,26 @@ class BayesianModel:
293
316
  posterior_observations = self.mock_observations(prior_params, key=key_posterior)
294
317
 
295
318
  for key, value in self.observation_container.items():
296
- fig, axs = plt.subplots(
319
+ fig, ax = plt.subplots(
297
320
  nrows=2, ncols=1, sharex=True, figsize=(5, 6), height_ratios=[3, 1]
298
321
  )
299
322
 
300
- _plot_poisson_data_with_error(
301
- axs[0],
302
- value.out_energies,
303
- value.folded_counts.values,
304
- percentiles=(16, 84),
323
+ y_observed, y_observed_low, y_observed_high = _error_bars_for_observed_data(
324
+ value.folded_counts.values, 1.0, "ct"
305
325
  )
306
326
 
307
- for i, (envelop_percentiles, color, alpha) in enumerate(
308
- zip(
309
- [(16, 86), (2.5, 97.5), (0.15, 99.85)],
310
- ["#03045e", "#0077b6", "#00b4d8"],
311
- [0.5, 0.4, 0.3],
312
- )
313
- ):
314
- lower, upper = np.percentile(
315
- posterior_observations["obs_" + key], envelop_percentiles, axis=0
316
- )
327
+ true_data_plot = _plot_poisson_data_with_error(
328
+ ax[0],
329
+ value.out_energies,
330
+ y_observed.value,
331
+ y_observed_low.value,
332
+ y_observed_high.value,
333
+ alpha=0.7,
334
+ )
317
335
 
318
- axs[0].stairs(
319
- upper,
320
- edges=[*list(value.out_energies[0]), value.out_energies[1][-1]],
321
- baseline=lower,
322
- alpha=alpha,
323
- fill=True,
324
- color=color,
325
- label=rf"${1+i}\sigma$",
326
- )
336
+ prior_plot = _plot_binned_samples_with_error(
337
+ ax[0], value.out_energies, posterior_observations["obs_" + key], n_sigmas=3
338
+ )
327
339
 
328
340
  # rank = np.vstack((posterior_observations["obs_" + key], value.folded_counts.values)).argsort(axis=0)[-1] / (num_samples) * 100
329
341
  counts = posterior_observations["obs_" + key]
@@ -336,22 +348,22 @@ class BayesianModel:
336
348
 
337
349
  rank = (less_than_obs + 0.5 * equal_to_obs) / num_samples * 100
338
350
 
339
- axs[1].stairs(rank, edges=[*list(value.out_energies[0]), value.out_energies[1][-1]])
351
+ ax[1].stairs(rank, edges=[*list(value.out_energies[0]), value.out_energies[1][-1]])
340
352
 
341
- axs[1].plot(
353
+ ax[1].plot(
342
354
  (value.out_energies.min(), value.out_energies.max()),
343
355
  (50, 50),
344
356
  color="black",
345
357
  linestyle="--",
346
358
  )
347
359
 
348
- axs[1].set_xlabel("Energy (keV)")
349
- axs[0].set_ylabel("Counts")
350
- axs[1].set_ylabel("Rank (%)")
351
- axs[1].set_ylim(0, 100)
352
- axs[0].set_xlim(value.out_energies.min(), value.out_energies.max())
353
- axs[0].loglog()
354
- axs[0].legend(loc="upper right")
360
+ ax[1].set_xlabel("Energy (keV)")
361
+ ax[0].set_ylabel("Counts")
362
+ ax[1].set_ylabel("Rank (%)")
363
+ ax[1].set_ylim(0, 100)
364
+ ax[0].set_xlim(value.out_energies.min(), value.out_energies.max())
365
+ ax[0].loglog()
366
+ ax[0].legend(loc="upper right")
355
367
  plt.suptitle(f"Prior Predictive coverage for {key}")
356
368
  plt.tight_layout()
357
369
  plt.show()
@@ -544,7 +556,6 @@ class MCMCFitter(BayesianModelFitter):
544
556
  return FitResult(
545
557
  self,
546
558
  inference_data,
547
- self.model.params,
548
559
  background_model=self.background_model,
549
560
  )
550
561
 
@@ -590,11 +601,13 @@ class NSFitter(BayesianModelFitter):
590
601
  ns = NestedSampler(
591
602
  bayesian_model,
592
603
  constructor_kwargs=dict(
593
- num_parallel_workers=1,
594
604
  verbose=verbose,
595
605
  difficult_model=True,
596
- max_samples=1e6,
606
+ max_samples=1e5,
597
607
  parameter_estimation=True,
608
+ gradient_guided=True,
609
+ devices=jax.devices(),
610
+ # init_efficiency_threshold=0.01,
598
611
  num_live_points=num_live_points,
599
612
  ),
600
613
  termination_kwargs=termination_kwargs if termination_kwargs else dict(),
@@ -613,6 +626,5 @@ class NSFitter(BayesianModelFitter):
613
626
  return FitResult(
614
627
  self,
615
628
  inference_data,
616
- self.model.params,
617
629
  background_model=self.background_model,
618
630
  )