jaxspec 0.1.3__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,7 +1,10 @@
1
1
  import os
2
+
2
3
  import numpy as np
3
4
  import xarray as xr
5
+
4
6
  from matplotlib import colors
7
+
5
8
  from .ogip import DataARF, DataRMF
6
9
 
7
10
 
@@ -25,7 +28,15 @@ class Instrument(xr.Dataset):
25
28
  )
26
29
 
27
30
  @classmethod
28
- def from_matrix(cls, redistribution_matrix, spectral_response, e_min_unfolded, e_max_unfolded, e_min_channel, e_max_channel):
31
+ def from_matrix(
32
+ cls,
33
+ redistribution_matrix,
34
+ spectral_response,
35
+ e_min_unfolded,
36
+ e_max_unfolded,
37
+ e_min_channel,
38
+ e_max_channel,
39
+ ):
29
40
  return cls(
30
41
  {
31
42
  "redistribution": (
@@ -65,7 +76,7 @@ class Instrument(xr.Dataset):
65
76
  )
66
77
 
67
78
  @classmethod
68
- def from_ogip_file(cls, rmf_path: str | os.PathLike, arf_path: str | os.PathLike = None):
79
+ def from_ogip_file(cls, rmf_path: str | os.PathLike, arf_path: str | os.PathLike | None = None):
69
80
  """
70
81
  Load the data from OGIP files.
71
82
 
@@ -80,37 +91,61 @@ class Instrument(xr.Dataset):
80
91
  specresp = DataARF.from_file(arf_path).specresp
81
92
 
82
93
  else:
83
- specresp = np.ones(rmf.energ_lo.shape)
94
+ specresp = rmf.matrix.sum(axis=0)
95
+ rmf.sparse_matrix /= specresp
84
96
 
85
- return cls.from_matrix(rmf.sparse_matrix, specresp, rmf.energ_lo, rmf.energ_hi, rmf.e_min, rmf.e_max)
97
+ return cls.from_matrix(
98
+ rmf.sparse_matrix, specresp, rmf.energ_lo, rmf.energ_hi, rmf.e_min, rmf.e_max
99
+ )
86
100
 
87
- def plot_redistribution(self, **kwargs):
101
+ def plot_redistribution(
102
+ self,
103
+ xscale: str = "log",
104
+ yscale: str = "log",
105
+ cmap=None,
106
+ vmin: float = 1e-6,
107
+ vmax: float = 1e0,
108
+ add_labels: bool = True,
109
+ **kwargs,
110
+ ):
88
111
  """
89
112
  Plot the redistribution probability matrix
90
113
 
91
114
  Parameters:
115
+ xscale : The scale of the x-axis.
116
+ yscale : The scale of the y-axis.
117
+ cmap : The colormap to use.
118
+ vmin : The minimum value for the colormap.
119
+ vmax : The maximum value for the colormap.
120
+ add_labels : Whether to add labels to the plot.
92
121
  **kwargs : `kwargs` passed to https://docs.xarray.dev/en/latest/generated/xarray.plot.pcolormesh.html#xarray.plot.pcolormesh
93
122
  """
123
+
94
124
  import cmasher as cmr
95
125
 
96
126
  return xr.plot.pcolormesh(
97
127
  self.redistribution,
98
128
  x="e_max_unfolded",
99
129
  y="e_max_channel",
100
- xscale="log",
101
- yscale="log",
102
- cmap=cmr.ember_r,
103
- norm=colors.LogNorm(vmin=1e-6, vmax=1),
104
- add_labels=True,
130
+ xscale=xscale,
131
+ yscale=yscale,
132
+ cmap=cmr.ember_r if cmap is None else cmap,
133
+ norm=colors.LogNorm(vmin=vmin, vmax=vmax),
134
+ add_labels=add_labels,
105
135
  **kwargs,
106
136
  )
107
137
 
108
- def plot_area(self, **kwargs):
138
+ def plot_area(self, xscale: str = "log", yscale: str = "log", where: str = "post", **kwargs):
109
139
  """
110
140
  Plot the effective area
111
141
 
112
142
  Parameters:
143
+ xscale : The scale of the x-axis.
144
+ yscale : The scale of the y-axis.
145
+ where : The position of the steps.
113
146
  **kwargs : `kwargs` passed to https://docs.xarray.dev/en/latest/generated/xarray.DataArray.plot.line.html#xarray.DataArray.plot.line
114
147
  """
115
148
 
116
- return self.area.plot.step(x="e_min_unfolded", xscale="log", yscale="log", where="post", **kwargs)
149
+ return self.area.plot.step(
150
+ x="e_min_unfolded", xscale=xscale, yscale=yscale, where=where, **kwargs
151
+ )
jaxspec/data/obsconf.py CHANGED
@@ -157,8 +157,10 @@ class ObsConfiguration(xr.Dataset):
157
157
 
158
158
  if observation.folded_background is not None:
159
159
  folded_background = observation.folded_background.data[row_idx]
160
+ folded_background_unscaled = observation.folded_background_unscaled.data[row_idx]
160
161
  else:
161
162
  folded_background = np.zeros_like(folded_counts)
163
+ folded_background_unscaled = np.zeros_like(folded_counts)
162
164
 
163
165
  data_dict = {
164
166
  "transfer_matrix": (
@@ -200,6 +202,14 @@ class ObsConfiguration(xr.Dataset):
200
202
  "unit": "photons",
201
203
  },
202
204
  ),
205
+ "folded_background_unscaled": (
206
+ ["folded_channel"],
207
+ folded_background_unscaled,
208
+ {
209
+ "description": "To be done",
210
+ "unit": "photons",
211
+ },
212
+ ),
203
213
  }
204
214
 
205
215
  return cls(
@@ -234,8 +244,8 @@ class ObsConfiguration(xr.Dataset):
234
244
  cls,
235
245
  instrument: Instrument,
236
246
  exposure: float,
237
- low_energy: float = 1e-20,
238
- high_energy: float = 1e20,
247
+ low_energy: float = 1e-300,
248
+ high_energy: float = 1e300,
239
249
  ):
240
250
  """
241
251
  Create a mock observation configuration from an instrument object. The fake observation will have zero counts.
@@ -1,5 +1,6 @@
1
1
  import numpy as np
2
2
  import xarray as xr
3
+
3
4
  from .ogip import DataPHA
4
5
 
5
6
 
@@ -23,7 +24,16 @@ class Observation(xr.Dataset):
23
24
  folded_background: xr.DataArray
24
25
  """The background counts, after grouping"""
25
26
 
26
- __slots__ = ("grouping", "channel", "quality", "exposure", "background", "folded_background", "counts", "folded_counts")
27
+ __slots__ = (
28
+ "grouping",
29
+ "channel",
30
+ "quality",
31
+ "exposure",
32
+ "background",
33
+ "folded_background",
34
+ "counts",
35
+ "folded_counts",
36
+ )
27
37
 
28
38
  _default_attributes = {"description": "X-ray observation dataset"}
29
39
 
@@ -36,17 +46,23 @@ class Observation(xr.Dataset):
36
46
  quality,
37
47
  exposure,
38
48
  background=None,
49
+ background_unscaled=None,
39
50
  backratio=1.0,
40
51
  attributes: dict | None = None,
41
52
  ):
42
53
  if attributes is None:
43
54
  attributes = {}
44
55
 
45
- if background is None:
56
+ if background is None or background_unscaled is None:
46
57
  background = np.zeros_like(counts, dtype=np.int64)
58
+ background_unscaled = np.zeros_like(counts, dtype=np.int64)
47
59
 
48
60
  data_dict = {
49
- "counts": (["instrument_channel"], np.asarray(counts, dtype=np.int64), {"description": "Counts", "unit": "photons"}),
61
+ "counts": (
62
+ ["instrument_channel"],
63
+ np.asarray(counts, dtype=np.int64),
64
+ {"description": "Counts", "unit": "photons"},
65
+ ),
50
66
  "folded_counts": (
51
67
  ["folded_channel"],
52
68
  np.asarray(np.ma.filled(grouping @ counts), dtype=np.int64),
@@ -57,7 +73,11 @@ class Observation(xr.Dataset):
57
73
  grouping,
58
74
  {"description": "Grouping matrix."},
59
75
  ),
60
- "quality": (["instrument_channel"], np.asarray(quality, dtype=np.int64), {"description": "Quality flag."}),
76
+ "quality": (
77
+ ["instrument_channel"],
78
+ np.asarray(quality, dtype=np.int64),
79
+ {"description": "Quality flag."},
80
+ ),
61
81
  "exposure": ([], float(exposure), {"description": "Total exposure", "unit": "s"}),
62
82
  "backratio": (
63
83
  ["instrument_channel"],
@@ -74,9 +94,14 @@ class Observation(xr.Dataset):
74
94
  np.asarray(background, dtype=np.int64),
75
95
  {"description": "Background counts", "unit": "photons"},
76
96
  ),
97
+ "folded_background_unscaled": (
98
+ ["folded_channel"],
99
+ np.asarray(np.ma.filled(grouping @ background_unscaled), dtype=np.int64),
100
+ {"description": "Background counts", "unit": "photons"},
101
+ ),
77
102
  "folded_background": (
78
103
  ["folded_channel"],
79
- np.asarray(np.ma.filled(grouping @ background), dtype=np.int64),
104
+ np.asarray(np.ma.filled(grouping @ background), dtype=np.float64),
80
105
  {"description": "Background counts", "unit": "photons"},
81
106
  ),
82
107
  }
@@ -84,36 +109,59 @@ class Observation(xr.Dataset):
84
109
  return cls(
85
110
  data_dict,
86
111
  coords={
87
- "channel": (["instrument_channel"], np.asarray(channel, dtype=np.int64), {"description": "Channel number"}),
112
+ "channel": (
113
+ ["instrument_channel"],
114
+ np.asarray(channel, dtype=np.int64),
115
+ {"description": "Channel number"},
116
+ ),
88
117
  "grouped_channel": (
89
118
  ["folded_channel"],
90
119
  np.arange(len(grouping @ counts), dtype=np.int64),
91
120
  {"description": "Channel number"},
92
121
  ),
93
122
  },
94
- attrs=cls._default_attributes if attributes is None else attributes | cls._default_attributes,
123
+ attrs=cls._default_attributes
124
+ if attributes is None
125
+ else attributes | cls._default_attributes,
95
126
  )
96
127
 
97
128
  @classmethod
98
129
  def from_ogip_container(cls, pha: DataPHA, bkg: DataPHA | None = None, **metadata):
99
130
  if bkg is not None:
100
- backratio = np.nan_to_num((pha.backscal * pha.exposure * pha.areascal) / (bkg.backscal * bkg.exposure * bkg.areascal))
131
+ backratio = np.nan_to_num(
132
+ (pha.backscal * pha.exposure * pha.areascal)
133
+ / (bkg.backscal * bkg.exposure * bkg.areascal)
134
+ )
101
135
  else:
102
136
  backratio = np.ones_like(pha.counts)
103
137
 
138
+ if (bkg is not None) and ("NET" in pha.flags):
139
+ counts = pha.counts + bkg.counts * backratio
140
+ else:
141
+ counts = pha.counts
142
+
104
143
  return cls.from_matrix(
105
- pha.counts,
144
+ counts,
106
145
  pha.grouping,
107
146
  pha.channel,
108
147
  pha.quality,
109
148
  pha.exposure,
110
149
  backratio=backratio,
111
- background=bkg.counts if bkg is not None else None,
150
+ background=bkg.counts * backratio if bkg is not None else None,
151
+ background_unscaled=bkg.counts if bkg is not None else None,
112
152
  attributes=metadata,
113
153
  )
114
154
 
115
155
  @classmethod
116
156
  def from_pha_file(cls, pha_path: str, bkg_path: str | None = None, **metadata):
157
+ """
158
+ Build an observation from a PHA file
159
+
160
+ Parameters:
161
+ pha_path : Path to the PHA file
162
+ bkg_path : Path to the background file
163
+ metadata : Additional metadata to add to the observation
164
+ """
117
165
  from .util import data_path_finder
118
166
 
119
167
  arf_path, rmf_path, bkg_path_default = data_path_finder(pha_path)
@@ -155,7 +203,16 @@ class Observation(xr.Dataset):
155
203
 
156
204
  fig = plt.figure(figsize=(6, 6))
157
205
  gs = fig.add_gridspec(
158
- 2, 2, width_ratios=(4, 1), height_ratios=(1, 4), left=0.1, right=0.9, bottom=0.1, top=0.9, wspace=0.05, hspace=0.05
206
+ 2,
207
+ 2,
208
+ width_ratios=(4, 1),
209
+ height_ratios=(1, 4),
210
+ left=0.1,
211
+ right=0.9,
212
+ bottom=0.1,
213
+ top=0.9,
214
+ wspace=0.05,
215
+ hspace=0.05,
159
216
  )
160
217
  ax = fig.add_subplot(gs[1, 0])
161
218
  ax_histx = fig.add_subplot(gs[0, 0], sharex=ax)
jaxspec/data/ogip.py CHANGED
@@ -1,9 +1,11 @@
1
- import numpy as np
2
1
  import os
2
+
3
3
  import astropy.units as u
4
+ import numpy as np
4
5
  import sparse
5
- from astropy.table import QTable
6
+
6
7
  from astropy.io import fits
8
+ from astropy.table import QTable
7
9
 
8
10
 
9
11
  class DataPHA:
@@ -25,6 +27,7 @@ class DataPHA:
25
27
  ancrfile=None,
26
28
  backscal=1.0,
27
29
  areascal=1.0,
30
+ flags=None,
28
31
  ):
29
32
  self.channel = np.asarray(channel, dtype=int)
30
33
  self.counts = np.asarray(counts, dtype=int)
@@ -36,6 +39,7 @@ class DataPHA:
36
39
  self.ancrfile = ancrfile
37
40
  self.backscal = np.asarray(backscal, dtype=float)
38
41
  self.areascal = np.asarray(areascal, dtype=float)
42
+ self.flags = flags
39
43
 
40
44
  if grouping is not None:
41
45
  # Indices array of the beginning of each group
@@ -55,7 +59,9 @@ class DataPHA:
55
59
  data.append(True)
56
60
 
57
61
  # Create a COO sparse matrix
58
- grp_matrix = sparse.COO((data, (rows, cols)), shape=(len(b_grp), len(channel)), fill_value=0)
62
+ grp_matrix = sparse.COO(
63
+ (data, (rows, cols)), shape=(len(b_grp), len(channel)), fill_value=0
64
+ )
59
65
 
60
66
  else:
61
67
  # Identity matrix case, use sparse for efficiency
@@ -74,12 +80,7 @@ class DataPHA:
74
80
 
75
81
  data = QTable.read(pha_file, "SPECTRUM")
76
82
  header = fits.getheader(pha_file, "SPECTRUM")
77
-
78
- if header.get("HDUCLAS2") == "NET":
79
- raise ValueError(
80
- f"The HDUCLAS2={header.get('HDUCLAS2')} keyword in the PHA file is not supported."
81
- f"Please open an issue if this is required."
82
- )
83
+ flags = []
83
84
 
84
85
  if header.get("HDUCLAS3") == "RATE":
85
86
  raise ValueError(
@@ -114,6 +115,8 @@ class DataPHA:
114
115
  else:
115
116
  raise ValueError("No BACKSCAL found in the PHA file.")
116
117
 
118
+ backscal = np.where(backscal == 0, 1.0, backscal)
119
+
117
120
  if "AREASCAL" in header:
118
121
  areascal = header["AREASCAL"]
119
122
  elif "AREASCAL" in data.colnames:
@@ -121,8 +124,9 @@ class DataPHA:
121
124
  else:
122
125
  raise ValueError("No AREASCAL found in the PHA file.")
123
126
 
124
- # Grouping and quality parameters are in binned PHA dataset
125
- # Backfile, respfile and ancrfile are in primary header
127
+ if header.get("HDUCLAS2") == "NET":
128
+ flags.append("NET")
129
+
126
130
  kwargs = {
127
131
  "grouping": grouping,
128
132
  "quality": quality,
@@ -131,6 +135,7 @@ class DataPHA:
131
135
  "ancrfile": header.get("ANCRFILE"),
132
136
  "backscal": backscal,
133
137
  "areascal": areascal,
138
+ "flags": flags,
134
139
  }
135
140
 
136
141
  return cls(data["CHANNEL"], data["COUNTS"], header["EXPOSURE"], **kwargs)
@@ -176,7 +181,19 @@ class DataRMF:
176
181
  * [The Calibration Requirements for Spectral Analysis Addendum: Changes log](https://heasarc.gsfc.nasa.gov/docs/heasarc/caldb/docs/memos/cal_gen_92_002a/cal_gen_92_002a.html)
177
182
  """
178
183
 
179
- def __init__(self, energ_lo, energ_hi, n_grp, f_chan, n_chan, matrix, channel, e_min, e_max, low_threshold=0.0):
184
+ def __init__(
185
+ self,
186
+ energ_lo,
187
+ energ_hi,
188
+ n_grp,
189
+ f_chan,
190
+ n_chan,
191
+ matrix,
192
+ channel,
193
+ e_min,
194
+ e_max,
195
+ low_threshold=0.0,
196
+ ):
180
197
  # RMF stuff
181
198
  self.energ_lo = energ_lo # "Entry" energies
182
199
  self.energ_hi = energ_hi # "Entry" energies
@@ -229,7 +246,9 @@ class DataRMF:
229
246
  idxs = data > low_threshold
230
247
 
231
248
  # Create a COO sparse matrix and then convert to CSR for efficiency
232
- coo = sparse.COO([rows[idxs], cols[idxs]], data[idxs], shape=(len(self.energ_lo), len(self.channel)))
249
+ coo = sparse.COO(
250
+ [rows[idxs], cols[idxs]], data[idxs], shape=(len(self.energ_lo), len(self.channel))
251
+ )
233
252
  self.sparse_matrix = coo.T # .tocsr()
234
253
 
235
254
  @property
jaxspec/data/util.py CHANGED
@@ -2,16 +2,13 @@ from collections.abc import Mapping
2
2
  from pathlib import Path
3
3
  from typing import Literal, TypeVar
4
4
 
5
- import haiku as hk
6
5
  import jax
7
- import numpy as np
8
6
  import numpyro
9
7
 
10
8
  from astropy.io import fits
11
- from numpy.typing import ArrayLike
12
9
  from numpyro import handlers
13
10
 
14
- from ..fit import CountForwardModel
11
+ from .._fit._build_model import forward_model
15
12
  from ..model.abc import SpectralModel
16
13
  from ..util.online_storage import table_manager
17
14
  from . import Instrument, ObsConfiguration, Observation
@@ -127,67 +124,6 @@ def load_example_obsconf(source: Literal["NGC7793_ULX4_PN", "NGC7793_ULX4_ALL"])
127
124
  raise ValueError(f"{source} not recognized.")
128
125
 
129
126
 
130
- def fakeit(
131
- instrument: ObsConfiguration | list[ObsConfiguration],
132
- model: SpectralModel,
133
- parameters: Mapping[K, V],
134
- rng_key: int = 0,
135
- sparsify_matrix: bool = False,
136
- ) -> ArrayLike | list[ArrayLike]:
137
- """
138
- Convenience function to simulate a spectrum from a given model and a set of parameters.
139
- It requires an instrumental setup, and unlike in
140
- [XSPEC's fakeit](https://heasarc.gsfc.nasa.gov/xanadu/xspec/manual/node72.html), the error on the counts is given
141
- exclusively by Poisson statistics.
142
-
143
- Parameters:
144
- instrument: The instrumental setup.
145
- model: The model to use.
146
- parameters: The parameters of the model.
147
- rng_key: The random number generator seed.
148
- sparsify_matrix: Whether to sparsify the matrix or not.
149
- """
150
-
151
- instruments = [instrument] if isinstance(instrument, ObsConfiguration) else instrument
152
- fakeits = []
153
-
154
- for i, instrument in enumerate(instruments):
155
- transformed_model = hk.without_apply_rng(
156
- hk.transform(
157
- lambda par: CountForwardModel(model, instrument, sparse=sparsify_matrix)(par)
158
- )
159
- )
160
-
161
- def obs_model(p):
162
- return transformed_model.apply(None, p)
163
-
164
- with handlers.seed(rng_seed=rng_key):
165
- counts = numpyro.sample(
166
- f"likelihood_obs_{i}",
167
- numpyro.distributions.Poisson(obs_model(parameters)),
168
- )
169
-
170
- """
171
- pha = DataPHA(
172
- instrument.rmf.channel,
173
- np.array(counts, dtype=int)*u.ct,
174
- instrument.exposure,
175
- grouping=instrument.grouping)
176
-
177
- observation = Observation(
178
- pha=pha,
179
- arf=instrument.arf,
180
- rmf=instrument.rmf,
181
- low_energy=instrument.low_energy,
182
- high_energy=instrument.high_energy
183
- )
184
- """
185
-
186
- fakeits.append(np.array(counts, dtype=int))
187
-
188
- return fakeits[0] if len(fakeits) == 1 else fakeits
189
-
190
-
191
127
  def fakeit_for_multiple_parameters(
192
128
  instrument: ObsConfiguration | list[ObsConfiguration],
193
129
  model: SpectralModel,
@@ -199,7 +135,6 @@ def fakeit_for_multiple_parameters(
199
135
  """
200
136
  Convenience function to simulate multiple spectra from a given model and a set of parameters.
201
137
 
202
- TODO : avoid redundancy, better doc and type hints
203
138
 
204
139
  Parameters:
205
140
  instrument: The instrumental setup.
@@ -214,24 +149,19 @@ def fakeit_for_multiple_parameters(
214
149
  fakeits = []
215
150
 
216
151
  for i, obs in enumerate(instruments):
217
- transformed_model = hk.without_apply_rng(
218
- hk.transform(lambda par: CountForwardModel(model, obs, sparse=sparsify_matrix)(par))
152
+ countrate = jax.vmap(lambda p: forward_model(model, p, instrument, sparse=sparsify_matrix))(
153
+ parameters
219
154
  )
220
155
 
221
- @jax.jit
222
- @jax.vmap
223
- def obs_model(p):
224
- return transformed_model.apply(None, p)
225
-
226
156
  if apply_stat:
227
157
  with handlers.seed(rng_seed=rng_key):
228
158
  spectrum = numpyro.sample(
229
159
  f"likelihood_obs_{i}",
230
- numpyro.distributions.Poisson(obs_model(parameters)),
160
+ numpyro.distributions.Poisson(countrate),
231
161
  )
232
162
 
233
163
  else:
234
- spectrum = obs_model(parameters)
164
+ spectrum = countrate
235
165
 
236
166
  fakeits.append(spectrum)
237
167