jaxspec 0.1.4__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,7 +1,10 @@
1
1
  import os
2
+
2
3
  import numpy as np
3
4
  import xarray as xr
5
+
4
6
  from matplotlib import colors
7
+
5
8
  from .ogip import DataARF, DataRMF
6
9
 
7
10
 
@@ -25,7 +28,15 @@ class Instrument(xr.Dataset):
25
28
  )
26
29
 
27
30
  @classmethod
28
- def from_matrix(cls, redistribution_matrix, spectral_response, e_min_unfolded, e_max_unfolded, e_min_channel, e_max_channel):
31
+ def from_matrix(
32
+ cls,
33
+ redistribution_matrix,
34
+ spectral_response,
35
+ e_min_unfolded,
36
+ e_max_unfolded,
37
+ e_min_channel,
38
+ e_max_channel,
39
+ ):
29
40
  return cls(
30
41
  {
31
42
  "redistribution": (
@@ -65,7 +76,7 @@ class Instrument(xr.Dataset):
65
76
  )
66
77
 
67
78
  @classmethod
68
- def from_ogip_file(cls, rmf_path: str | os.PathLike, arf_path: str | os.PathLike = None):
79
+ def from_ogip_file(cls, rmf_path: str | os.PathLike, arf_path: str | os.PathLike | None = None):
69
80
  """
70
81
  Load the data from OGIP files.
71
82
 
@@ -80,37 +91,61 @@ class Instrument(xr.Dataset):
80
91
  specresp = DataARF.from_file(arf_path).specresp
81
92
 
82
93
  else:
83
- specresp = np.ones(rmf.energ_lo.shape)
94
+ specresp = rmf.matrix.sum(axis=0)
95
+ rmf.sparse_matrix /= specresp
84
96
 
85
- return cls.from_matrix(rmf.sparse_matrix, specresp, rmf.energ_lo, rmf.energ_hi, rmf.e_min, rmf.e_max)
97
+ return cls.from_matrix(
98
+ rmf.sparse_matrix, specresp, rmf.energ_lo, rmf.energ_hi, rmf.e_min, rmf.e_max
99
+ )
86
100
 
87
- def plot_redistribution(self, **kwargs):
101
+ def plot_redistribution(
102
+ self,
103
+ xscale: str = "log",
104
+ yscale: str = "log",
105
+ cmap=None,
106
+ vmin: float = 1e-6,
107
+ vmax: float = 1e0,
108
+ add_labels: bool = True,
109
+ **kwargs,
110
+ ):
88
111
  """
89
112
  Plot the redistribution probability matrix
90
113
 
91
114
  Parameters:
115
+ xscale : The scale of the x-axis.
116
+ yscale : The scale of the y-axis.
117
+ cmap : The colormap to use.
118
+ vmin : The minimum value for the colormap.
119
+ vmax : The maximum value for the colormap.
120
+ add_labels : Whether to add labels to the plot.
92
121
  **kwargs : `kwargs` passed to https://docs.xarray.dev/en/latest/generated/xarray.plot.pcolormesh.html#xarray.plot.pcolormesh
93
122
  """
123
+
94
124
  import cmasher as cmr
95
125
 
96
126
  return xr.plot.pcolormesh(
97
127
  self.redistribution,
98
128
  x="e_max_unfolded",
99
129
  y="e_max_channel",
100
- xscale="log",
101
- yscale="log",
102
- cmap=cmr.ember_r,
103
- norm=colors.LogNorm(vmin=1e-6, vmax=1),
104
- add_labels=True,
130
+ xscale=xscale,
131
+ yscale=yscale,
132
+ cmap=cmr.ember_r if cmap is None else cmap,
133
+ norm=colors.LogNorm(vmin=vmin, vmax=vmax),
134
+ add_labels=add_labels,
105
135
  **kwargs,
106
136
  )
107
137
 
108
- def plot_area(self, **kwargs):
138
+ def plot_area(self, xscale: str = "log", yscale: str = "log", where: str = "post", **kwargs):
109
139
  """
110
140
  Plot the effective area
111
141
 
112
142
  Parameters:
143
+ xscale : The scale of the x-axis.
144
+ yscale : The scale of the y-axis.
145
+ where : The position of the steps.
113
146
  **kwargs : `kwargs` passed to https://docs.xarray.dev/en/latest/generated/xarray.DataArray.plot.line.html#xarray.DataArray.plot.line
114
147
  """
115
148
 
116
- return self.area.plot.step(x="e_min_unfolded", xscale="log", yscale="log", where="post", **kwargs)
149
+ return self.area.plot.step(
150
+ x="e_min_unfolded", xscale=xscale, yscale=yscale, where=where, **kwargs
151
+ )
jaxspec/data/obsconf.py CHANGED
@@ -157,8 +157,10 @@ class ObsConfiguration(xr.Dataset):
157
157
 
158
158
  if observation.folded_background is not None:
159
159
  folded_background = observation.folded_background.data[row_idx]
160
+ folded_background_unscaled = observation.folded_background_unscaled.data[row_idx]
160
161
  else:
161
162
  folded_background = np.zeros_like(folded_counts)
163
+ folded_background_unscaled = np.zeros_like(folded_counts)
162
164
 
163
165
  data_dict = {
164
166
  "transfer_matrix": (
@@ -200,6 +202,14 @@ class ObsConfiguration(xr.Dataset):
200
202
  "unit": "photons",
201
203
  },
202
204
  ),
205
+ "folded_background_unscaled": (
206
+ ["folded_channel"],
207
+ folded_background_unscaled,
208
+ {
209
+ "description": "To be done",
210
+ "unit": "photons",
211
+ },
212
+ ),
203
213
  }
204
214
 
205
215
  return cls(
@@ -234,8 +244,8 @@ class ObsConfiguration(xr.Dataset):
234
244
  cls,
235
245
  instrument: Instrument,
236
246
  exposure: float,
237
- low_energy: float = 1e-20,
238
- high_energy: float = 1e20,
247
+ low_energy: float = 1e-300,
248
+ high_energy: float = 1e300,
239
249
  ):
240
250
  """
241
251
  Create a mock observation configuration from an instrument object. The fake observation will have zero counts.
@@ -46,14 +46,16 @@ class Observation(xr.Dataset):
46
46
  quality,
47
47
  exposure,
48
48
  background=None,
49
+ background_unscaled=None,
49
50
  backratio=1.0,
50
51
  attributes: dict | None = None,
51
52
  ):
52
53
  if attributes is None:
53
54
  attributes = {}
54
55
 
55
- if background is None:
56
+ if background is None or background_unscaled is None:
56
57
  background = np.zeros_like(counts, dtype=np.int64)
58
+ background_unscaled = np.zeros_like(counts, dtype=np.int64)
57
59
 
58
60
  data_dict = {
59
61
  "counts": (
@@ -92,9 +94,14 @@ class Observation(xr.Dataset):
92
94
  np.asarray(background, dtype=np.int64),
93
95
  {"description": "Background counts", "unit": "photons"},
94
96
  ),
97
+ "folded_background_unscaled": (
98
+ ["folded_channel"],
99
+ np.asarray(np.ma.filled(grouping @ background_unscaled), dtype=np.int64),
100
+ {"description": "Background counts", "unit": "photons"},
101
+ ),
95
102
  "folded_background": (
96
103
  ["folded_channel"],
97
- np.asarray(np.ma.filled(grouping @ background), dtype=np.int64),
104
+ np.asarray(np.ma.filled(grouping @ background), dtype=np.float64),
98
105
  {"description": "Background counts", "unit": "photons"},
99
106
  ),
100
107
  }
@@ -128,14 +135,20 @@ class Observation(xr.Dataset):
128
135
  else:
129
136
  backratio = np.ones_like(pha.counts)
130
137
 
138
+ if (bkg is not None) and ("NET" in pha.flags):
139
+ counts = pha.counts + bkg.counts * backratio
140
+ else:
141
+ counts = pha.counts
142
+
131
143
  return cls.from_matrix(
132
- pha.counts,
144
+ counts,
133
145
  pha.grouping,
134
146
  pha.channel,
135
147
  pha.quality,
136
148
  pha.exposure,
137
149
  backratio=backratio,
138
- background=bkg.counts if bkg is not None else None,
150
+ background=bkg.counts * backratio if bkg is not None else None,
151
+ background_unscaled=bkg.counts if bkg is not None else None,
139
152
  attributes=metadata,
140
153
  )
141
154
 
jaxspec/data/ogip.py CHANGED
@@ -1,9 +1,11 @@
1
- import numpy as np
2
1
  import os
2
+
3
3
  import astropy.units as u
4
+ import numpy as np
4
5
  import sparse
5
- from astropy.table import QTable
6
+
6
7
  from astropy.io import fits
8
+ from astropy.table import QTable
7
9
 
8
10
 
9
11
  class DataPHA:
@@ -25,6 +27,7 @@ class DataPHA:
25
27
  ancrfile=None,
26
28
  backscal=1.0,
27
29
  areascal=1.0,
30
+ flags=None,
28
31
  ):
29
32
  self.channel = np.asarray(channel, dtype=int)
30
33
  self.counts = np.asarray(counts, dtype=int)
@@ -36,6 +39,7 @@ class DataPHA:
36
39
  self.ancrfile = ancrfile
37
40
  self.backscal = np.asarray(backscal, dtype=float)
38
41
  self.areascal = np.asarray(areascal, dtype=float)
42
+ self.flags = flags
39
43
 
40
44
  if grouping is not None:
41
45
  # Indices array of the beginning of each group
@@ -55,7 +59,9 @@ class DataPHA:
55
59
  data.append(True)
56
60
 
57
61
  # Create a COO sparse matrix
58
- grp_matrix = sparse.COO((data, (rows, cols)), shape=(len(b_grp), len(channel)), fill_value=0)
62
+ grp_matrix = sparse.COO(
63
+ (data, (rows, cols)), shape=(len(b_grp), len(channel)), fill_value=0
64
+ )
59
65
 
60
66
  else:
61
67
  # Identity matrix case, use sparse for efficiency
@@ -74,12 +80,7 @@ class DataPHA:
74
80
 
75
81
  data = QTable.read(pha_file, "SPECTRUM")
76
82
  header = fits.getheader(pha_file, "SPECTRUM")
77
-
78
- if header.get("HDUCLAS2") == "NET":
79
- raise ValueError(
80
- f"The HDUCLAS2={header.get('HDUCLAS2')} keyword in the PHA file is not supported."
81
- f"Please open an issue if this is required."
82
- )
83
+ flags = []
83
84
 
84
85
  if header.get("HDUCLAS3") == "RATE":
85
86
  raise ValueError(
@@ -114,6 +115,8 @@ class DataPHA:
114
115
  else:
115
116
  raise ValueError("No BACKSCAL found in the PHA file.")
116
117
 
118
+ backscal = np.where(backscal == 0, 1.0, backscal)
119
+
117
120
  if "AREASCAL" in header:
118
121
  areascal = header["AREASCAL"]
119
122
  elif "AREASCAL" in data.colnames:
@@ -121,8 +124,9 @@ class DataPHA:
121
124
  else:
122
125
  raise ValueError("No AREASCAL found in the PHA file.")
123
126
 
124
- # Grouping and quality parameters are in binned PHA dataset
125
- # Backfile, respfile and ancrfile are in primary header
127
+ if header.get("HDUCLAS2") == "NET":
128
+ flags.append("NET")
129
+
126
130
  kwargs = {
127
131
  "grouping": grouping,
128
132
  "quality": quality,
@@ -131,6 +135,7 @@ class DataPHA:
131
135
  "ancrfile": header.get("ANCRFILE"),
132
136
  "backscal": backscal,
133
137
  "areascal": areascal,
138
+ "flags": flags,
134
139
  }
135
140
 
136
141
  return cls(data["CHANNEL"], data["COUNTS"], header["EXPOSURE"], **kwargs)
@@ -176,7 +181,19 @@ class DataRMF:
176
181
  * [The Calibration Requirements for Spectral Analysis Addendum: Changes log](https://heasarc.gsfc.nasa.gov/docs/heasarc/caldb/docs/memos/cal_gen_92_002a/cal_gen_92_002a.html)
177
182
  """
178
183
 
179
- def __init__(self, energ_lo, energ_hi, n_grp, f_chan, n_chan, matrix, channel, e_min, e_max, low_threshold=0.0):
184
+ def __init__(
185
+ self,
186
+ energ_lo,
187
+ energ_hi,
188
+ n_grp,
189
+ f_chan,
190
+ n_chan,
191
+ matrix,
192
+ channel,
193
+ e_min,
194
+ e_max,
195
+ low_threshold=0.0,
196
+ ):
180
197
  # RMF stuff
181
198
  self.energ_lo = energ_lo # "Entry" energies
182
199
  self.energ_hi = energ_hi # "Entry" energies
@@ -229,7 +246,9 @@ class DataRMF:
229
246
  idxs = data > low_threshold
230
247
 
231
248
  # Create a COO sparse matrix and then convert to CSR for efficiency
232
- coo = sparse.COO([rows[idxs], cols[idxs]], data[idxs], shape=(len(self.energ_lo), len(self.channel)))
249
+ coo = sparse.COO(
250
+ [rows[idxs], cols[idxs]], data[idxs], shape=(len(self.energ_lo), len(self.channel))
251
+ )
233
252
  self.sparse_matrix = coo.T # .tocsr()
234
253
 
235
254
  @property
jaxspec/data/util.py CHANGED
@@ -1,17 +1,16 @@
1
1
  from collections.abc import Mapping
2
2
  from pathlib import Path
3
- from typing import Literal, TypeVar
3
+ from typing import TYPE_CHECKING, Literal, TypeVar
4
4
 
5
- import haiku as hk
6
5
  import jax
6
+ import jax.numpy as jnp
7
7
  import numpy as np
8
8
  import numpyro
9
9
 
10
10
  from astropy.io import fits
11
- from numpy.typing import ArrayLike
11
+ from jax.experimental.sparse import BCOO
12
12
  from numpyro import handlers
13
13
 
14
- from .._fit._build_model import CountForwardModel
15
14
  from ..model.abc import SpectralModel
16
15
  from ..util.online_storage import table_manager
17
16
  from . import Instrument, ObsConfiguration, Observation
@@ -19,6 +18,10 @@ from . import Instrument, ObsConfiguration, Observation
19
18
  K = TypeVar("K")
20
19
  V = TypeVar("V")
21
20
 
21
+ if TYPE_CHECKING:
22
+ from ..data import ObsConfiguration
23
+ from ..model.abc import SpectralModel
24
+
22
25
 
23
26
  def load_example_pha(
24
27
  source: Literal["NGC7793_ULX4_PN", "NGC7793_ULX4_ALL"],
@@ -127,69 +130,40 @@ def load_example_obsconf(source: Literal["NGC7793_ULX4_PN", "NGC7793_ULX4_ALL"])
127
130
  raise ValueError(f"{source} not recognized.")
128
131
 
129
132
 
130
- def fakeit(
131
- instrument: ObsConfiguration | list[ObsConfiguration],
132
- model: SpectralModel,
133
- parameters: Mapping[K, V],
134
- rng_key: int = 0,
135
- sparsify_matrix: bool = False,
136
- ) -> ArrayLike | list[ArrayLike]:
137
- """
138
- Convenience function to simulate a spectrum from a given model and a set of parameters.
139
- It requires an instrumental setup, and unlike in
140
- [XSPEC's fakeit](https://heasarc.gsfc.nasa.gov/xanadu/xspec/manual/node72.html), the error on the counts is given
141
- exclusively by Poisson statistics.
142
-
143
- Parameters:
144
- instrument: The instrumental setup.
145
- model: The model to use.
146
- parameters: The parameters of the model.
147
- rng_key: The random number generator seed.
148
- sparsify_matrix: Whether to sparsify the matrix or not.
149
- """
150
-
151
- instruments = [instrument] if isinstance(instrument, ObsConfiguration) else instrument
152
- fakeits = []
133
+ def forward_model_with_multiple_inputs(
134
+ model: "SpectralModel",
135
+ parameters,
136
+ obs_configuration: "ObsConfiguration",
137
+ sparse=False,
138
+ ):
139
+ energies = np.asarray(obs_configuration.in_energies)
140
+ parameter_dims = next(iter(parameters.values())).shape
153
141
 
154
- for i, instrument in enumerate(instruments):
155
- transformed_model = hk.without_apply_rng(
156
- hk.transform(
157
- lambda par: CountForwardModel(model, instrument, sparse=sparsify_matrix)(par)
158
- )
159
- )
142
+ def flux_func(p):
143
+ return model.photon_flux(p, *energies)
160
144
 
161
- def obs_model(p):
162
- return transformed_model.apply(None, p)
145
+ for _ in parameter_dims:
146
+ flux_func = jax.vmap(flux_func)
163
147
 
164
- with handlers.seed(rng_seed=rng_key):
165
- counts = numpyro.sample(
166
- f"likelihood_obs_{i}",
167
- numpyro.distributions.Poisson(obs_model(parameters)),
168
- )
148
+ flux_func = jax.jit(flux_func)
169
149
 
170
- """
171
- pha = DataPHA(
172
- instrument.rmf.channel,
173
- np.array(counts, dtype=int)*u.ct,
174
- instrument.exposure,
175
- grouping=instrument.grouping)
176
-
177
- observation = Observation(
178
- pha=pha,
179
- arf=instrument.arf,
180
- rmf=instrument.rmf,
181
- low_energy=instrument.low_energy,
182
- high_energy=instrument.high_energy
150
+ if sparse:
151
+ # folding.transfer_matrix.data.density > 0.015 is a good criterion to consider sparsify
152
+ transfer_matrix = BCOO.from_scipy_sparse(
153
+ obs_configuration.transfer_matrix.data.to_scipy_sparse().tocsr()
183
154
  )
184
- """
185
155
 
186
- fakeits.append(np.array(counts, dtype=int))
156
+ else:
157
+ transfer_matrix = np.asarray(obs_configuration.transfer_matrix.data.todense())
187
158
 
188
- return fakeits[0] if len(fakeits) == 1 else fakeits
159
+ expected_counts = jnp.matvec(transfer_matrix, flux_func(parameters))
160
+
161
+ # The result is clipped at 1e-6 to avoid 0 round-off and diverging likelihoods
162
+ return jnp.clip(expected_counts, a_min=1e-6)
189
163
 
190
164
 
191
165
  def fakeit_for_multiple_parameters(
192
- instrument: ObsConfiguration | list[ObsConfiguration],
166
+ obsconfs: ObsConfiguration | list[ObsConfiguration],
193
167
  model: SpectralModel,
194
168
  parameters: Mapping[K, V],
195
169
  rng_key: int = 0,
@@ -198,11 +172,32 @@ def fakeit_for_multiple_parameters(
198
172
  ):
199
173
  """
200
174
  Convenience function to simulate multiple spectra from a given model and a set of parameters.
175
+ This is supposed to be somewhat optimized and can handle multiple parameters at once without blowing
176
+ up the memory. The parameters should be passed as a dictionary with the parameter name as the key and
177
+ the parameter values as the values, the value can be a scalar or a nd-array.
178
+
179
+ # Example:
180
+
181
+ ``` python
182
+ from jaxspec.data.util import fakeit_for_multiple_parameters
183
+ from numpy.random import default_rng
201
184
 
202
- TODO : avoid redundancy, better doc and type hints
185
+ rng = default_rng(42)
186
+ size = (10, 30)
187
+
188
+ parameters = {
189
+ "tbabs_1_nh": rng.uniform(0.1, 0.4, size=size),
190
+ "powerlaw_1_alpha": rng.uniform(1, 3, size=size),
191
+ "powerlaw_1_norm": rng.exponential(10 ** (-0.5), size=size),
192
+ "blackbodyrad_1_kT": rng.uniform(0.1, 3.0, size=size),
193
+ "blackbodyrad_1_norm": rng.exponential(10 ** (-3), size=size)
194
+ }
195
+
196
+ spectra = fakeit_for_multiple_parameters(obsconf, model, parameters)
197
+ ```
203
198
 
204
199
  Parameters:
205
- instrument: The instrumental setup.
200
+ obsconfs: The observational setup(s).
206
201
  model: The model to use.
207
202
  parameters: The parameters of the model.
208
203
  rng_key: The random number generator seed.
@@ -210,28 +205,23 @@ def fakeit_for_multiple_parameters(
210
205
  sparsify_matrix: Whether to sparsify the matrix or not.
211
206
  """
212
207
 
213
- instruments = [instrument] if isinstance(instrument, ObsConfiguration) else instrument
208
+ obsconf_list = [obsconfs] if isinstance(obsconfs, ObsConfiguration) else obsconfs
214
209
  fakeits = []
215
210
 
216
- for i, obs in enumerate(instruments):
217
- transformed_model = hk.without_apply_rng(
218
- hk.transform(lambda par: CountForwardModel(model, obs, sparse=sparsify_matrix)(par))
211
+ for i, obsconf in enumerate(obsconf_list):
212
+ countrate = forward_model_with_multiple_inputs(
213
+ model, parameters, obsconf, sparse=sparsify_matrix
219
214
  )
220
215
 
221
- @jax.jit
222
- @jax.vmap
223
- def obs_model(p):
224
- return transformed_model.apply(None, p)
225
-
226
216
  if apply_stat:
227
217
  with handlers.seed(rng_seed=rng_key):
228
218
  spectrum = numpyro.sample(
229
219
  f"likelihood_obs_{i}",
230
- numpyro.distributions.Poisson(obs_model(parameters)),
220
+ numpyro.distributions.Poisson(countrate),
231
221
  )
232
222
 
233
223
  else:
234
- spectrum = obs_model(parameters)
224
+ spectrum = countrate
235
225
 
236
226
  fakeits.append(spectrum)
237
227