jaxspec 0.0.7__py3-none-any.whl → 0.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,5 +1,7 @@
1
+ from __future__ import annotations
2
+
1
3
  from collections.abc import Mapping
2
- from typing import Any, Literal, TypeVar
4
+ from typing import TYPE_CHECKING, Any, Literal, TypeVar
3
5
 
4
6
  import arviz as az
5
7
  import astropy.units as u
@@ -14,13 +16,14 @@ from astropy.units import Unit
14
16
  from chainconsumer import Chain, ChainConsumer, PlotConfig
15
17
  from haiku.data_structures import traverse
16
18
  from jax.typing import ArrayLike
19
+ from numpyro.handlers import seed
17
20
  from scipy.integrate import trapezoid
18
21
  from scipy.special import gammaln
19
22
  from scipy.stats import nbinom
20
23
 
21
- from ..data import ObsConfiguration
22
- from ..model.abc import SpectralModel
23
- from ..model.background import BackgroundModel
24
+ if TYPE_CHECKING:
25
+ from ..fit import BayesianModel
26
+ from ..model.background import BackgroundModel
24
27
 
25
28
  K = TypeVar("K")
26
29
  V = TypeVar("V")
@@ -96,18 +99,15 @@ class FitResult:
96
99
  # TODO : Add type hints
97
100
  def __init__(
98
101
  self,
99
- model: SpectralModel,
100
- obsconf: ObsConfiguration | dict[str, ObsConfiguration],
102
+ bayesian_fitter: BayesianModel,
101
103
  inference_data: az.InferenceData,
102
104
  structure: Mapping[K, V],
103
105
  background_model: BackgroundModel = None,
104
106
  ):
105
- self.model = model
106
- self._structure = structure
107
+ self.model = bayesian_fitter.model
108
+ self.bayesian_fitter = bayesian_fitter
107
109
  self.inference_data = inference_data
108
- self.obsconfs = (
109
- {"Observation": obsconf} if isinstance(obsconf, ObsConfiguration) else obsconf
110
- )
110
+ self.obsconfs = bayesian_fitter.observation_container
111
111
  self.background_model = background_model
112
112
  self._structure = structure
113
113
 
@@ -115,7 +115,7 @@ class FitResult:
115
115
  for group in self.inference_data.groups():
116
116
  group_name = group.split("/")[-1]
117
117
  metadata = getattr(self.inference_data, group_name).attrs
118
- metadata["model"] = str(model)
118
+ metadata["model"] = str(self.model)
119
119
  # TODO : Store metadata about observations used in the fitting process
120
120
 
121
121
  @property
@@ -132,9 +132,7 @@ class FitResult:
132
132
  Get samples from the parameter posterior distribution but keep their shape in terms of draw and chains.
133
133
  """
134
134
 
135
- var_names = [f"{m}_{n}" for m, n, _ in traverse(self._structure)]
136
- posterior = az.extract(self.inference_data, var_names=var_names, combined=False)
137
- samples_flat = {key: posterior[key].data for key in var_names}
135
+ samples_flat = self._structured_samples_flat
138
136
 
139
137
  samples_haiku = {}
140
138
 
@@ -145,6 +143,60 @@ class FitResult:
145
143
 
146
144
  return samples_haiku
147
145
 
146
+ @property
147
+ def _structured_samples_flat(self):
148
+ """
149
+ Return posterior samples as a flat dict keyed by "<module>_<parameter>", keeping the chain and draw dimensions (combined=False).
150
+ """
151
+
152
+ var_names = [f"{m}_{n}" for m, n, _ in traverse(self._structure)]
153
+ posterior = az.extract(self.inference_data, var_names=var_names, combined=False)
154
+ samples_flat = {key: posterior[key].data for key in var_names}
155
+
156
+ return samples_flat
157
+
158
+ @property
159
+ def input_parameters(self) -> HaikuDict[ArrayLike]:
160
+ """
161
+ The full model parameter pytree, with every parameter (sampled or fixed) shaped to match the posterior samples.
162
+ """
163
+
164
+ posterior = az.extract(self.inference_data, combined=False)
165
+
166
+ samples_shape = (len(posterior.coords["chain"]), len(posterior.coords["draw"]))
167
+
168
+ total_shape = tuple(posterior.sizes[d] for d in posterior.coords)  # NOTE(review): assumes every coordinate is also a dimension in posterior.sizes - confirm
169
+
170
+ posterior = {key: posterior[key].data for key in posterior.data_vars}
171
+
172
+ with seed(rng_seed=0):
173
+ input_parameters = self.bayesian_fitter.prior_distributions_func()  # only used for its structure and the values of fixed parameters
174
+
175
+ for module, parameter, value in traverse(input_parameters):
176
+ if f"{module}_{parameter}" in posterior.keys():
177
+ # Add an extra trailing dimension, as there might be different values per observation
178
+ if posterior[f"{module}_{parameter}"].shape == samples_shape:
179
+ to_set = posterior[f"{module}_{parameter}"][..., None]
180
+ else:
181
+ to_set = posterior[f"{module}_{parameter}"]
182
+
183
+ input_parameters[module][parameter] = to_set
184
+
185
+ else:
186
+ # The parameter is fixed in this case, so we just broadcast it over chains and draws
187
+ input_parameters[module][parameter] = value[None, None, ...]
188
+
189
+ if len(total_shape) < len(input_parameters[module][parameter].shape):
190
+ # If there are only chains and draws, drop the extra trailing dimension
191
+ input_parameters[module][parameter] = input_parameters[module][parameter][..., 0]
192
+
193
+ else:
194
+ input_parameters[module][parameter] = jnp.broadcast_to(
195
+ input_parameters[module][parameter], total_shape
196
+ )
197
+
198
+ return input_parameters
199
+
148
200
  def photon_flux(
149
201
  self,
150
202
  e_min: float,
@@ -156,8 +208,7 @@ class FitResult:
156
208
  Compute the unfolded photon flux in a given energy band. The flux is then added to
157
209
  the result parameters so covariance can be plotted.
158
210
 
159
- Parameters
160
- ----------
211
+ Parameters:
161
212
  e_min: The lower bound of the energy band in observer frame.
162
213
  e_max: The upper bound of the energy band in observer frame.
163
214
  unit: The unit of the photon flux.
@@ -168,20 +219,22 @@ class FitResult:
168
219
  [issue](https://github.com/renecotyfanboy/jaxspec/issues) in the GitHub repository.
169
220
  """
170
221
 
171
- samples = self._structured_samples
172
- init_shape = jax.tree.leaves(samples)[0].shape
173
-
174
- flux = jax.vmap(
175
- lambda p: self.model.photon_flux(p, jnp.asarray([e_min]), jnp.asarray([e_max]))
176
- )(jax.tree.map(lambda x: x.ravel(), samples))
222
+ @jax.jit
223
+ @jnp.vectorize
224
+ def vectorized_flux(*pars):
225
+ parameters_pytree = jax.tree.unflatten(pytree_def, pars)
226
+ return self.model.photon_flux(
227
+ parameters_pytree, jnp.asarray([e_min]), jnp.asarray([e_max]), n_points=100
228
+ )[0]
177
229
 
178
- flux = jax.tree.map(lambda x: x.reshape(init_shape), flux)
230
+ flat_tree, pytree_def = jax.tree.flatten(self.input_parameters)
231
+ flux = vectorized_flux(*flat_tree)
179
232
  conversion_factor = (u.photon / u.cm**2 / u.s).to(unit)
180
233
  value = flux * conversion_factor
181
234
 
182
235
  if register:
183
- self.inference_data.posterior[f"flux_{e_min:.1f}_{e_max:.1f}"] = (
184
- ["chain", "draw"],
236
+ self.inference_data.posterior[f"photon_flux_{e_min:.1f}_{e_max:.1f}"] = (
237
+ list(self.inference_data.posterior.coords),
185
238
  value,
186
239
  )
187
240
 
@@ -198,8 +251,7 @@ class FitResult:
198
251
  Compute the unfolded energy flux in a given energy band. The flux is then added to
199
252
  the result parameters so covariance can be plotted.
200
253
 
201
- Parameters
202
- ----------
254
+ Parameters:
203
255
  e_min: The lower bound of the energy band in observer frame.
204
256
  e_max: The upper bound of the energy band in observer frame.
205
257
  unit: The unit of the energy flux.
@@ -210,21 +262,22 @@ class FitResult:
210
262
  [issue](https://github.com/renecotyfanboy/jaxspec/issues) in the GitHub repository.
211
263
  """
212
264
 
213
- samples = self._structured_samples
214
- init_shape = jax.tree.leaves(samples)[0].shape
215
-
216
- flux = jax.vmap(
217
- lambda p: self.model.energy_flux(p, jnp.asarray([e_min]), jnp.asarray([e_max]))
218
- )(jax.tree.map(lambda x: x.ravel(), samples))
219
-
220
- flux = jax.tree.map(lambda x: x.reshape(init_shape), flux)
265
+ @jax.jit
266
+ @jnp.vectorize
267
+ def vectorized_flux(*pars):
268
+ parameters_pytree = jax.tree.unflatten(pytree_def, pars)
269
+ return self.model.energy_flux(
270
+ parameters_pytree, jnp.asarray([e_min]), jnp.asarray([e_max]), n_points=100
271
+ )[0]
221
272
 
273
+ flat_tree, pytree_def = jax.tree.flatten(self.input_parameters)
274
+ flux = vectorized_flux(*flat_tree)
222
275
  conversion_factor = (u.keV / u.cm**2 / u.s).to(unit)
223
276
  value = flux * conversion_factor
224
277
 
225
278
  if register:
226
- self.inference_data.posterior[f"eflux_{e_min:.1f}_{e_max:.1f}"] = (
227
- ["chain", "draw"],
279
+ self.inference_data.posterior[f"energy_flux_{e_min:.1f}_{e_max:.1f}"] = (
280
+ list(self.inference_data.posterior.coords),
228
281
  value,
229
282
  )
230
283
 
@@ -244,8 +297,7 @@ class FitResult:
244
297
  Compute the luminosity of the source specifying its redshift. The luminosity is then added to
245
298
  the result parameters so covariance can be plotted.
246
299
 
247
- Parameters
248
- ----------
300
+ Parameters:
249
301
  e_min: The lower bound of the energy band.
250
302
  e_max: The upper bound of the energy band.
251
303
  redshift: The redshift of the source. It can be a distribution of redshifts.
@@ -258,24 +310,24 @@ class FitResult:
258
310
  if not observer_frame:
259
311
  raise NotImplementedError()
260
312
 
261
- samples = self._structured_samples
262
- init_shape = jax.tree.leaves(samples)[0].shape
263
-
264
- flux = jax.vmap(
265
- lambda p: self.model.energy_flux(
266
- p, jnp.asarray([e_min]) * (1 + redshift), jnp.asarray([e_max])
267
- )
268
- * (1 + redshift)
269
- )(jax.tree.map(lambda x: x.ravel(), samples))
270
-
271
- flux = jax.tree.map(
272
- lambda x: np.asarray(x.reshape(init_shape)) * (u.keV / u.cm**2 / u.s), flux
273
- )
313
+ @jax.jit
314
+ @jnp.vectorize
315
+ def vectorized_flux(*pars):
316
+ parameters_pytree = jax.tree.unflatten(pytree_def, pars)
317
+ return self.model.energy_flux(
318
+ parameters_pytree,
319
+ jnp.asarray([e_min]) * (1 + redshift),
320
+ jnp.asarray([e_max]) * (1 + redshift),
321
+ n_points=100,
322
+ )[0]
323
+
324
+ flat_tree, pytree_def = jax.tree.flatten(self.input_parameters)
325
+ flux = vectorized_flux(*flat_tree) * (u.keV / u.cm**2 / u.s)
274
326
  value = (flux * (4 * np.pi * cosmology.luminosity_distance(redshift) ** 2)).to(unit)
275
327
 
276
328
  if register:
277
329
  self.inference_data.posterior[f"luminosity_{e_min:.1f}_{e_max:.1f}"] = (
278
- ["chain", "draw"],
330
+ list(self.inference_data.posterior.coords),
279
331
  value,
280
332
  )
281
333
 
@@ -285,8 +337,7 @@ class FitResult:
285
337
  """
286
338
  Return a ChainConsumer Chain object from the posterior distribution of the parameters_type.
287
339
 
288
- Parameters
289
- ----------
340
+ Parameters:
290
341
  name: The name of the chain.
291
342
  parameters_type: The parameters_type to include in the chain.
292
343
  """
@@ -424,14 +475,12 @@ class FitResult:
424
475
  $$ \text{Residual} = \frac{\text{Observed counts} - \text{Posterior counts}}
425
476
  {(\text{Posterior counts})_{84\%}-(\text{Posterior counts})_{16\%}} $$
426
477
 
427
- Parameters
428
- ----------
478
+ Parameters:
429
479
  percentile: The percentile of the posterior predictive distribution to plot.
430
480
  x_unit: The units of the x-axis. It can be either a string (parsable by astropy.units) or an astropy unit. It must be homogeneous to either a length, a frequency or an energy.
431
481
  y_type: The type of the y-axis. It can be either "counts", "countrate", "photon_flux" or "photon_flux_density".
432
482
 
433
483
  Returns:
434
- -------
435
484
  The matplotlib figure.
436
485
  """
437
486
 
@@ -668,10 +717,8 @@ class FitResult:
668
717
  """
669
718
  Plot the corner plot of the posterior distribution of the parameters_type. This method uses the ChainConsumer.
670
719
 
671
- Parameters
672
- ----------
720
+ Parameters:
673
721
  config: The configuration of the plot.
674
- parameters: The parameters to include in the plot using the following format: `blackbody_1_kT`.
675
722
  **kwargs: Additional arguments passed to ChainConsumer.plotter.plot. Some useful parameters are :
676
723
  - columns : list of parameters to plot.
677
724
  """
jaxspec/data/util.py CHANGED
@@ -1,8 +1,6 @@
1
- import importlib.resources
2
-
3
1
  from collections.abc import Mapping
4
2
  from pathlib import Path
5
- from typing import TypeVar
3
+ from typing import Literal, TypeVar
6
4
 
7
5
  import haiku as hk
8
6
  import jax
@@ -15,91 +13,118 @@ from numpyro import handlers
15
13
 
16
14
  from ..fit import CountForwardModel
17
15
  from ..model.abc import SpectralModel
16
+ from ..util.online_storage import table_manager
18
17
  from . import Instrument, ObsConfiguration, Observation
19
18
 
20
19
  K = TypeVar("K")
21
20
  V = TypeVar("V")
22
21
 
23
22
 
24
- def load_example_observations():
23
+ def load_example_pha(
24
+ source: Literal["NGC7793_ULX4_PN", "NGC7793_ULX4_ALL"],
25
+ ) -> Observation | dict[str, Observation]:
25
26
  """
26
27
  Load some example observations from the package data.
28
+
29
+ Parameters:
30
+ source: The source to be loaded. Either "NGC7793_ULX4_PN" (returns a single Observation) or "NGC7793_ULX4_ALL" (returns a dict of Observations keyed by "PN", "MOS1" and "MOS2"); any other value raises a ValueError. Files are retrieved with `table_manager.fetch`.
27
31
  """
28
32
 
29
- example_observations = {
30
- "PN": Observation.from_pha_file(
31
- str(importlib.resources.files("jaxspec") / "data/example_data/PN_spectrum_grp20.fits"),
32
- low_energy=0.3,
33
- high_energy=7.5,
34
- ),
35
- "MOS1": Observation.from_pha_file(
36
- str(importlib.resources.files("jaxspec") / "data/example_data/MOS1_spectrum_grp.fits"),
37
- low_energy=0.3,
38
- high_energy=7,
39
- ),
40
- "MOS2": Observation.from_pha_file(
41
- str(importlib.resources.files("jaxspec") / "data/example_data/MOS2_spectrum_grp.fits"),
42
- low_energy=0.3,
43
- high_energy=7,
44
- ),
45
- }
46
-
47
- return example_observations
48
-
49
-
50
- def load_example_instruments():
33
+ if source == "NGC7793_ULX4_PN":
34
+ return Observation.from_pha_file(
35
+ table_manager.fetch("example_data/NGC7793_ULX4/PN_spectrum_grp20.fits"),
36
+ bkg_path=table_manager.fetch("example_data/NGC7793_ULX4/PNbackground_spectrum.fits"),
37
+ )
38
+
39
+ elif source == "NGC7793_ULX4_ALL":
40
+ return {
41
+ "PN": Observation.from_pha_file(
42
+ table_manager.fetch("example_data/NGC7793_ULX4/PN_spectrum_grp20.fits"),
43
+ bkg_path=table_manager.fetch(
44
+ "example_data/NGC7793_ULX4/PNbackground_spectrum.fits"
45
+ ),
46
+ ),
47
+ "MOS1": Observation.from_pha_file(
48
+ table_manager.fetch("example_data/NGC7793_ULX4/MOS1_spectrum_grp.fits"),
49
+ bkg_path=table_manager.fetch(
50
+ "example_data/NGC7793_ULX4/MOS1background_spectrum.fits"
51
+ ),
52
+ ),
53
+ "MOS2": Observation.from_pha_file(
54
+ table_manager.fetch("example_data/NGC7793_ULX4/MOS2_spectrum_grp.fits"),
55
+ bkg_path=table_manager.fetch(
56
+ "example_data/NGC7793_ULX4/MOS2background_spectrum.fits"
57
+ ),
58
+ ),
59
+ }
60
+
61
+ else:
62
+ raise ValueError(f"{source} not recognized.")
63
+
64
+
65
+ def load_example_instruments(source: Literal["NGC7793_ULX4_PN", "NGC7793_ULX4_ALL"]):
51
66
  """
52
67
  Load some example instruments from the package data.
68
+
69
+ Parameters:
70
+ source: The source to be loaded. Either "NGC7793_ULX4_PN" (returns a single Instrument) or "NGC7793_ULX4_ALL" (returns a dict of Instruments keyed by "PN", "MOS1" and "MOS2"); any other value raises a ValueError. RMF/ARF files are retrieved with `table_manager.fetch`.
71
+
53
72
  """
73
+ if source == "NGC7793_ULX4_PN":
74
+ return Instrument.from_ogip_file(
75
+ table_manager.fetch("example_data/NGC7793_ULX4/PN.rmf"),
76
+ table_manager.fetch("example_data/NGC7793_ULX4/PN.arf"),
77
+ )
78
+
79
+ elif source == "NGC7793_ULX4_ALL":
80
+ return {
81
+ "PN": Instrument.from_ogip_file(
82
+ table_manager.fetch("example_data/NGC7793_ULX4/PN.rmf"),
83
+ table_manager.fetch("example_data/NGC7793_ULX4/PN.arf"),
84
+ ),
85
+ "MOS1": Instrument.from_ogip_file(
86
+ table_manager.fetch("example_data/NGC7793_ULX4/MOS1.rmf"),
87
+ table_manager.fetch("example_data/NGC7793_ULX4/MOS1.arf"),
88
+ ),
89
+ "MOS2": Instrument.from_ogip_file(
90
+ table_manager.fetch("example_data/NGC7793_ULX4/MOS2.rmf"),
91
+ table_manager.fetch("example_data/NGC7793_ULX4/MOS2.arf"),
92
+ ),
93
+ }
54
94
 
55
- example_instruments = {
56
- "PN": Instrument.from_ogip_file(
57
- str(importlib.resources.files("jaxspec") / "data/example_data/PN.rmf"),
58
- str(importlib.resources.files("jaxspec") / "data/example_data/PN.arf"),
59
- ),
60
- "MOS1": Instrument.from_ogip_file(
61
- str(importlib.resources.files("jaxspec") / "data/example_data/MOS1.rmf"),
62
- str(importlib.resources.files("jaxspec") / "data/example_data/MOS1.arf"),
63
- ),
64
- "MOS2": Instrument.from_ogip_file(
65
- str(importlib.resources.files("jaxspec") / "data/example_data/MOS2.rmf"),
66
- str(importlib.resources.files("jaxspec") / "data/example_data/MOS2.arf"),
67
- ),
68
- }
69
-
70
- return example_instruments
71
-
72
-
73
- def load_example_foldings():
95
+ else:
96
+ raise ValueError(f"{source} not recognized.")
97
+
98
+
99
+ def load_example_obsconf(source: Literal["NGC7793_ULX4_PN", "NGC7793_ULX4_ALL"]) -> ObsConfiguration | dict[str, ObsConfiguration]:
74
100
  """
75
- Load some example instruments from the package data.
101
+ Load some example ObsConfigurations, built with low_energy=0.5 and high_energy=8.0.
102
+
103
+ Parameters:
104
+ source: The source to be loaded. Either "NGC7793_ULX4_PN" (returns a single ObsConfiguration) or "NGC7793_ULX4_ALL" (returns a dict of ObsConfigurations with the same keys as load_example_instruments); any other value raises a ValueError.
76
105
  """
77
106
 
78
- example_instruments = load_example_instruments()
79
- example_observations = load_example_observations()
80
-
81
- example_foldings = {
82
- "PN": ObsConfiguration.from_instrument(
83
- example_instruments["PN"],
84
- example_observations["PN"],
85
- low_energy=0.3,
86
- high_energy=8.0,
87
- ),
88
- "MOS1": ObsConfiguration.from_instrument(
89
- example_instruments["MOS1"],
90
- example_observations["MOS1"],
91
- low_energy=0.3,
92
- high_energy=7,
93
- ),
94
- "MOS2": ObsConfiguration.from_instrument(
95
- example_instruments["MOS2"],
96
- example_observations["MOS2"],
97
- low_energy=0.3,
98
- high_energy=7,
99
- ),
100
- }
101
-
102
- return example_foldings
107
+ if source == "NGC7793_ULX4_PN":  # exact match: `in` would also accept any substring such as "PN" or ""
108
+ instrument = load_example_instruments(source)
109
+ observation = load_example_pha(source)
110
+
111
+ return ObsConfiguration.from_instrument(
112
+ instrument, observation, low_energy=0.5, high_energy=8.0
113
+ )
114
+
115
+ elif source == "NGC7793_ULX4_ALL":
116
+ instruments_dict = load_example_instruments(source)
117
+ observations_dict = load_example_pha(source)
118
+
119
+ return {
120
+ key: ObsConfiguration.from_instrument(
121
+ instruments_dict[key], observations_dict[key], low_energy=0.5, high_energy=8.0
122
+ )
123
+ for key in instruments_dict
124
+ }
125
+
126
+ else:
127
+ raise ValueError(f"{source} not recognized.")
103
128
 
104
129
 
105
130
  def fakeit(
@@ -115,8 +140,7 @@ def fakeit(
115
140
  [XSPEC's fakeit](https://heasarc.gsfc.nasa.gov/xanadu/xspec/manual/node72.html), the error on the counts is given
116
141
  exclusively by Poisson statistics.
117
142
 
118
- Parameters
119
- ----------
143
+ Parameters:
120
144
  instrument: The instrumental setup.
121
145
  model: The model to use.
122
146
  parameters: The parameters of the model.
@@ -177,8 +201,7 @@ def fakeit_for_multiple_parameters(
177
201
 
178
202
  TODO : avoid redundancy, better doc and type hints
179
203
 
180
- Parameters
181
- ----------
204
+ Parameters:
182
205
  instrument: The instrumental setup.
183
206
  model: The model to use.
184
207
  parameters: The parameters of the model.
@@ -219,12 +242,10 @@ def data_path_finder(pha_path: str) -> tuple[str | None, str | None, str | None]
219
242
  """
220
243
  Function which tries its best to find the ARF, RMF and BKG files associated with a given PHA file.
221
244
 
222
- Parameters
223
- ----------
245
+ Parameters:
224
246
  pha_path: The PHA file path.
225
247
 
226
- Returns
227
- -------
248
+ Returns:
228
249
  arf_path: The ARF file path.
229
250
  rmf_path: The RMF file path.
230
251
  bkg_path: The BKG file path.