jaxspec 0.0.2__py3-none-any.whl → 0.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
- import os
2
1
  import numpy as np
3
2
  import xarray as xr
3
+ from .ogip import DataPHA
4
4
 
5
5
 
6
6
  class Observation(xr.Dataset):
@@ -43,31 +43,40 @@ class Observation(xr.Dataset):
43
43
  attributes = {}
44
44
 
45
45
  if background is None:
46
- background = np.zeros_like(counts, dtype=int)
46
+ background = np.zeros_like(counts, dtype=np.int64)
47
47
 
48
48
  data_dict = {
49
- "counts": (["instrument_channel"], np.array(counts, dtype=np.int64), {"description": "Counts", "unit": "photons"}),
49
+ "counts": (["instrument_channel"], np.asarray(counts, dtype=np.int64), {"description": "Counts", "unit": "photons"}),
50
50
  "folded_counts": (
51
51
  ["folded_channel"],
52
- np.array(grouping @ counts, dtype=int),
52
+ np.asarray(np.ma.filled(grouping @ counts), dtype=np.int64),
53
53
  {"description": "Folded counts, after grouping", "unit": "photons"},
54
54
  ),
55
55
  "grouping": (
56
56
  ["folded_channel", "instrument_channel"],
57
- np.array(grouping, dtype=bool),
57
+ grouping,
58
58
  {"description": "Grouping matrix."},
59
59
  ),
60
- "quality": (["instrument_channel"], np.array(quality, dtype=int), {"description": "Quality flag."}),
60
+ "quality": (["instrument_channel"], np.asarray(quality, dtype=np.int64), {"description": "Quality flag."}),
61
61
  "exposure": ([], float(exposure), {"description": "Total exposure", "unit": "s"}),
62
- "backratio": ([], float(backratio), {"description": "Background scaling (SRC_BACKSCAL/BKG_BACKSCAL)"}),
62
+ "backratio": (
63
+ ["instrument_channel"],
64
+ np.asarray(backratio, dtype=float),
65
+ {"description": "Background scaling (SRC_BACKSCAL/BKG_BACKSCAL)"},
66
+ ),
67
+ "folded_backratio": (
68
+ ["folded_channel"],
69
+ np.asarray(np.ma.filled(grouping @ backratio), dtype=float),
70
+ {"description": "Background scaling after grouping"},
71
+ ),
63
72
  "background": (
64
73
  ["instrument_channel"],
65
- np.array(background, dtype=int),
74
+ np.asarray(background, dtype=np.int64),
66
75
  {"description": "Background counts", "unit": "photons"},
67
76
  ),
68
77
  "folded_background": (
69
78
  ["folded_channel"],
70
- np.array(grouping @ background, dtype=int),
79
+ np.asarray(np.ma.filled(grouping @ background), dtype=np.int64),
71
80
  {"description": "Background counts", "unit": "photons"},
72
81
  ),
73
82
  }
@@ -75,7 +84,7 @@ class Observation(xr.Dataset):
75
84
  return cls(
76
85
  data_dict,
77
86
  coords={
78
- "channel": (["instrument_channel"], np.array(channel, dtype=np.int64), {"description": "Channel number"}),
87
+ "channel": (["instrument_channel"], np.asarray(channel, dtype=np.int64), {"description": "Channel number"}),
79
88
  "grouped_channel": (
80
89
  ["folded_channel"],
81
90
  np.arange(len(grouping @ counts), dtype=np.int64),
@@ -86,15 +95,11 @@ class Observation(xr.Dataset):
86
95
  )
87
96
 
88
97
  @classmethod
89
- def from_pha_file(cls, pha_file: str | os.PathLike, **kwargs):
90
- from .util import data_loader
91
-
92
- pha, arf, rmf, bkg, metadata = data_loader(pha_file)
93
-
98
+ def from_ogip_container(cls, pha: DataPHA, bkg: DataPHA | None = None, **metadata):
94
99
  if bkg is not None:
95
- backratio = (pha.backscal * pha.exposure * pha.areascal) / (bkg.backscal * bkg.exposure * bkg.areascal)
100
+ backratio = np.nan_to_num((pha.backscal * pha.exposure * pha.areascal) / (bkg.backscal * bkg.exposure * bkg.areascal))
96
101
  else:
97
- backratio = 1.0
102
+ backratio = np.ones_like(pha.counts)
98
103
 
99
104
  return cls.from_matrix(
100
105
  pha.counts,
@@ -107,6 +112,28 @@ class Observation(xr.Dataset):
107
112
  attributes=metadata,
108
113
  )
109
114
 
115
@classmethod
def from_pha_file(cls, pha_path: str, bkg_path: str | None = None, **metadata):
    """
    Build an observation from a PHA file, optionally with an explicit background file.

    The ARF, RMF and background paths referenced by the PHA header are resolved with
    ``data_path_finder``. An explicitly given ``bkg_path`` takes precedence over the
    background file found in the header. Only the PHA spectrum and the background
    spectrum are loaded here; the ARF and RMF paths are only recorded in the metadata.

    Parameters:
        pha_path: The PHA file path.
        bkg_path: Optional background PHA file path, overriding the header's BACKFILE.
        **metadata: Extra attributes forwarded to ``from_ogip_container``.

    Returns:
        The observation built by ``from_ogip_container``.

    Errors raised while locating or reading the files (e.g. ``FileNotFoundError``
    from the path lookup) are propagated to the caller.
    """
    # Local import, as in the original code (presumably to avoid an import cycle
    # with .util, which imports from this package).
    from .util import data_path_finder

    arf_path, rmf_path, header_bkg_path = data_path_finder(pha_path)
    if bkg_path is None:
        bkg_path = header_bkg_path

    pha = DataPHA.from_file(pha_path)
    bkg = DataPHA.from_file(bkg_path) if bkg_path is not None else None

    # ``metadata`` is the dict freshly collected from **kwargs: it is never None, and
    # updating it in place cannot affect the caller.
    metadata.update(
        observation_file=pha_path,
        background_file=bkg_path,
        response_matrix_file=rmf_path,
        ancillary_response_file=arf_path,
    )

    return cls.from_ogip_container(pha, bkg=bkg, **metadata)
136
+
110
137
  def plot_counts(self, **kwargs):
111
138
  """
112
139
  Plot the counts
@@ -133,7 +160,7 @@ class Observation(xr.Dataset):
133
160
  ax = fig.add_subplot(gs[1, 0])
134
161
  ax_histx = fig.add_subplot(gs[0, 0], sharex=ax)
135
162
  ax_histy = fig.add_subplot(gs[1, 1], sharey=ax)
136
- sns.heatmap(self.grouping.T, ax=ax, cbar=False)
163
+ sns.heatmap(self.grouping.data.todense().T, ax=ax, cbar=False)
137
164
  ax_histx.step(np.arange(len(self.folded_counts)), self.folded_counts, where="post")
138
165
  ax_histy.step(self.counts, np.arange(len(self.counts)), where="post")
139
166
 
jaxspec/data/ogip.py CHANGED
@@ -1,6 +1,7 @@
1
1
  import numpy as np
2
2
  import os
3
3
  import astropy.units as u
4
+ import sparse
4
5
  from astropy.table import QTable
5
6
  from astropy.io import fits
6
7
 
@@ -8,7 +9,6 @@ from astropy.io import fits
8
9
  class DataPHA:
9
10
  r"""
10
11
  Class to handle PHA data defined with OGIP standards.
11
-
12
12
  ??? info "References"
13
13
  * [The OGIP standard PHA file format](https://heasarc.gsfc.nasa.gov/docs/heasarc/ofwg/docs/spectra/ogip_92_007/node5.html)
14
14
  """
@@ -26,30 +26,40 @@ class DataPHA:
26
26
  backscal=1.0,
27
27
  areascal=1.0,
28
28
  ):
29
- self.channel = channel
30
- self.counts = counts
31
- self.exposure = exposure
29
+ self.channel = np.asarray(channel, dtype=int)
30
+ self.counts = np.asarray(counts, dtype=int)
31
+ self.exposure = float(exposure)
32
32
 
33
- self.quality = quality
33
+ self.quality = np.asarray(quality, dtype=int)
34
34
  self.backfile = backfile
35
35
  self.respfile = respfile
36
36
  self.ancrfile = ancrfile
37
- self.backscal = backscal
38
- self.areascal = areascal
37
+ self.backscal = np.asarray(backscal, dtype=float)
38
+ self.areascal = np.asarray(areascal, dtype=float)
39
39
 
40
40
  if grouping is not None:
41
- # Indices array of beginning of each group
41
+ # Indices array of the beginning of each group
42
42
  b_grp = np.where(grouping == 1)[0]
43
- # Indices array of ending of each group
43
+ # Indices array of the ending of each group
44
44
  e_grp = np.hstack((b_grp[1:], len(channel)))
45
- # Matrix to multiply with counts/channel to have counts/group
46
- grp_matrix = np.zeros((len(b_grp), len(channel)), dtype=bool)
45
+
46
+ # Prepare data for sparse matrix
47
+ rows = []
48
+ cols = []
49
+ data = []
47
50
 
48
51
  for i in range(len(b_grp)):
49
- grp_matrix[i, b_grp[i] : e_grp[i]] = 1
52
+ for j in range(b_grp[i], e_grp[i]):
53
+ rows.append(i)
54
+ cols.append(j)
55
+ data.append(True)
56
+
57
+ # Create a COO sparse matrix
58
+ grp_matrix = sparse.COO((data, (rows, cols)), shape=(len(b_grp), len(channel)), fill_value=0)
50
59
 
51
60
  else:
52
- grp_matrix = np.eye(len(channel))
61
+ # Identity matrix case, use sparse for efficiency
62
+ grp_matrix = sparse.eye(len(channel), format="coo", dtype=bool)
53
63
 
54
64
  self.grouping = grp_matrix
55
65
 
@@ -65,24 +75,44 @@ class DataPHA:
65
75
  data = QTable.read(pha_file, "SPECTRUM")
66
76
  header = fits.getheader(pha_file, "SPECTRUM")
67
77
 
68
- if "QUALITY" in data.colnames:
78
+ if header.get("GROUPING") == 0:
79
+ grouping = None
80
+ elif "GROUPING" in data.colnames:
81
+ grouping = data["GROUPING"]
82
+ else:
83
+ raise ValueError("No grouping column found in the PHA file.")
84
+
85
+ if header.get("QUALITY") == 0:
86
+ quality = np.zeros(len(data["CHANNEL"]), dtype=bool)
87
+ elif "QUALITY" in data.colnames:
69
88
  quality = data["QUALITY"]
70
89
  else:
71
- if header.get("QUALITY") == 0:
72
- quality = np.ones(len(data["CHANNEL"]), dtype=bool)
73
- else:
74
- raise ValueError("No quality column found in the PHA file.")
90
+ raise ValueError("No QUALITY column found in the PHA file.")
91
+
92
+ if "BACKSCAL" in header:
93
+ backscal = header["BACKSCAL"] * np.ones_like(data["CHANNEL"])
94
+ elif "BACKSCAL" in data.colnames:
95
+ backscal = data["BACKSCAL"]
96
+ else:
97
+ raise ValueError("No BACKSCAL found in the PHA file.")
98
+
99
+ if "AREASCAL" in header:
100
+ areascal = header["AREASCAL"]
101
+ elif "AREASCAL" in data.colnames:
102
+ areascal = data["AREASCAL"]
103
+ else:
104
+ raise ValueError("No AREASCAL found in the PHA file.")
75
105
 
76
106
  # Grouping and quality parameters are in binned PHA dataset
77
107
  # Backfile, respfile and ancrfile are in primary header
78
108
  kwargs = {
79
- "grouping": data["GROUPING"] if "GROUPING" in data.colnames else None,
109
+ "grouping": grouping,
80
110
  "quality": quality,
81
111
  "backfile": header.get("BACKFILE"),
82
112
  "respfile": header.get("RESPFILE"),
83
113
  "ancrfile": header.get("ANCRFILE"),
84
- "backscal": header.get("BACKSCAL", 1.0),
85
- "areascal": header.get("AREASCAL", 1.0),
114
+ "backscal": backscal,
115
+ "areascal": areascal,
86
116
  }
87
117
 
88
118
  return cls(data["CHANNEL"], data["COUNTS"], header["EXPOSURE"], **kwargs)
@@ -123,18 +153,16 @@ class DataARF:
123
153
  class DataRMF:
124
154
  r"""
125
155
  Class to handle RMF data defined with OGIP standards.
126
-
127
156
  ??? info "References"
128
157
  * [The Calibration Requirements for Spectral Analysis (Definition of RMF and ARF file formats)](https://heasarc.gsfc.nasa.gov/docs/heasarc/caldb/docs/memos/cal_gen_92_002/cal_gen_92_002.html)
129
158
  * [The Calibration Requirements for Spectral Analysis Addendum: Changes log](https://heasarc.gsfc.nasa.gov/docs/heasarc/caldb/docs/memos/cal_gen_92_002a/cal_gen_92_002a.html)
130
-
131
159
  """
132
160
 
133
- def __init__(self, energ_lo, energ_hi, n_grp, f_chan, n_chan, matrix, channel, e_min, e_max):
161
+ def __init__(self, energ_lo, energ_hi, n_grp, f_chan, n_chan, matrix, channel, e_min, e_max, low_threshold=0.0):
134
162
  # RMF stuff
135
163
  self.energ_lo = energ_lo # "Entry" energies
136
164
  self.energ_hi = energ_hi # "Entry" energies
137
- self.n_grp = n_grp # "Entry" energies
165
+ self.n_grp = n_grp
138
166
  self.f_chan = f_chan
139
167
  self.n_chan = n_chan
140
168
  self.matrix_entry = matrix
@@ -144,35 +172,51 @@ class DataRMF:
144
172
  self.e_min = e_min
145
173
  self.e_max = e_max
146
174
 
147
- self.full_matrix = np.zeros(self.n_grp.shape + self.channel.shape)
175
+ # Prepare data for sparse matrix
176
+ rows = []
177
+ cols = []
178
+ data = []
148
179
 
149
- for i, n_grp in enumerate(self.n_grp):
180
+ for i, n_grp_val in enumerate(self.n_grp):
150
181
  base = 0
151
182
 
152
183
  if np.size(self.f_chan[i]) == 1:
153
- # ravel()[0] allows to get the value of the array without triggering numpy's conversion from
154
- # multidimensional array to scalar
155
184
  low = int(self.f_chan[i].ravel()[0])
156
185
  high = min(
157
186
  int(self.f_chan[i].ravel()[0] + self.n_chan[i].ravel()[0]),
158
- self.full_matrix.shape[1],
187
+ len(self.channel),
159
188
  )
160
189
 
161
- self.full_matrix[i, low:high] = self.matrix_entry[i][0 : high - low]
190
+ rows.extend([i] * (high - low))
191
+ cols.extend(range(low, high))
192
+ data.extend(self.matrix_entry[i][0 : high - low])
162
193
 
163
194
  else:
164
- for j in range(n_grp):
195
+ for j in range(n_grp_val):
165
196
  low = self.f_chan[i][j]
166
- high = min(self.f_chan[i][j] + self.n_chan[i][j], self.full_matrix.shape[1])
197
+ high = min(self.f_chan[i][j] + self.n_chan[i][j], len(self.channel))
167
198
 
168
- self.full_matrix[i, low:high] = self.matrix_entry[i][base : base + self.n_chan[i][j]]
199
+ rows.extend([i] * (high - low))
200
+ cols.extend(range(low, high))
201
+ data.extend(self.matrix_entry[i][base : base + self.n_chan[i][j]])
169
202
 
170
203
  base += self.n_chan[i][j]
171
204
 
172
- # Transposed matrix so that we just have to multiply by the spectrum
173
- self.matrix = self.full_matrix.T
174
- # self.matrix = bsr_matrix(self.full_matrix.T).T
175
- # self.sparse_matrix = sparse.BCOO.fromdense(jnp.copy(self.full_matrix))
205
+ # Convert lists to numpy arrays
206
+ rows = np.array(rows)
207
+ cols = np.array(cols)
208
+ data = np.array(data)
209
+
210
+ # Sometimes, zero elements are given in the matrix rows, so we get rid of them
211
+ idxs = data > low_threshold
212
+
213
+ # Create a COO sparse matrix and then convert to CSR for efficiency
214
+ coo = sparse.COO([rows[idxs], cols[idxs]], data[idxs], shape=(len(self.energ_lo), len(self.channel)))
215
+ self.sparse_matrix = coo.T # .tocsr()
216
+
217
@property
def matrix(self):
    """Return ``sparse_matrix`` densified into a plain NumPy array."""
    dense = self.sparse_matrix.todense()
    return np.asarray(dense)
176
220
 
177
221
  @classmethod
178
222
  def from_file(cls, rmf_file: str | os.PathLike):
@@ -182,15 +226,31 @@ class DataRMF:
182
226
  Parameters:
183
227
  rmf_file: The RMF file path.
184
228
  """
229
+ extension_names = [hdu[1] for hdu in fits.info(rmf_file, output=False)]
230
+
231
+ if "MATRIX" in extension_names:
232
+ matrix_extension = "MATRIX"
185
233
 
186
- matrix_table = QTable.read(rmf_file, "MATRIX")
234
+ elif "SPECRESP MATRIX" in extension_names:
235
+ matrix_extension = "SPECRESP MATRIX"
236
+ # raise NotImplementedError("SPECRESP MATRIX extension is not yet supported")
237
+
238
+ else:
239
+ raise ValueError("No MATRIX or SPECRESP MATRIX extension found in the RMF file")
240
+
241
+ matrix_table = QTable.read(rmf_file, matrix_extension)
187
242
  ebounds_table = QTable.read(rmf_file, "EBOUNDS")
188
243
 
244
+ matrix_header = fits.getheader(rmf_file, matrix_extension)
245
+
246
+ f_chan_column_pos = list(matrix_table.columns).index("F_CHAN") + 1
247
+ tlmin_fchan = int(matrix_header[f"TLMIN{f_chan_column_pos}"])
248
+
189
249
  return cls(
190
250
  matrix_table["ENERG_LO"],
191
251
  matrix_table["ENERG_HI"],
192
252
  matrix_table["N_GRP"],
193
- matrix_table["F_CHAN"],
253
+ matrix_table["F_CHAN"] - tlmin_fchan,
194
254
  matrix_table["N_CHAN"],
195
255
  matrix_table["MATRIX"],
196
256
  ebounds_table["CHANNEL"],
jaxspec/data/util.py CHANGED
@@ -1,15 +1,14 @@
1
1
  import importlib.resources
2
- import os
3
- import warnings
4
2
  import numpyro
5
3
  import jax
6
4
  import numpy as np
7
5
  import haiku as hk
6
+ from pathlib import Path
8
7
  from numpy.typing import ArrayLike
9
8
  from collections.abc import Mapping
10
- from typing import TypeVar
9
+ from typing import TypeVar, Tuple
10
+ from astropy.io import fits
11
11
 
12
- from .ogip import DataPHA, DataARF, DataRMF
13
12
  from . import Observation, Instrument, ObsConfiguration
14
13
  from ..model.abc import SpectralModel
15
14
  from ..fit import CountForwardModel
@@ -26,17 +25,17 @@ def load_example_observations():
26
25
 
27
26
  example_observations = {
28
27
  "PN": Observation.from_pha_file(
29
- importlib.resources.files("jaxspec") / "data/example_data/PN_spectrum_grp20.fits",
28
+ str(importlib.resources.files("jaxspec") / "data/example_data/PN_spectrum_grp20.fits"),
30
29
  low_energy=0.3,
31
30
  high_energy=7.5,
32
31
  ),
33
32
  "MOS1": Observation.from_pha_file(
34
- importlib.resources.files("jaxspec") / "data/example_data/MOS1_spectrum_grp.fits",
33
+ str(importlib.resources.files("jaxspec") / "data/example_data/MOS1_spectrum_grp.fits"),
35
34
  low_energy=0.3,
36
35
  high_energy=7,
37
36
  ),
38
37
  "MOS2": Observation.from_pha_file(
39
- importlib.resources.files("jaxspec") / "data/example_data/MOS2_spectrum_grp.fits",
38
+ str(importlib.resources.files("jaxspec") / "data/example_data/MOS2_spectrum_grp.fits"),
40
39
  low_energy=0.3,
41
40
  high_energy=7,
42
41
  ),
@@ -52,16 +51,16 @@ def load_example_instruments():
52
51
 
53
52
  example_instruments = {
54
53
  "PN": Instrument.from_ogip_file(
55
- importlib.resources.files("jaxspec") / "data/example_data/PN.arf",
56
- importlib.resources.files("jaxspec") / "data/example_data/PN.rmf",
54
+ str(importlib.resources.files("jaxspec") / "data/example_data/PN.rmf"),
55
+ str(importlib.resources.files("jaxspec") / "data/example_data/PN.arf"),
57
56
  ),
58
57
  "MOS1": Instrument.from_ogip_file(
59
- importlib.resources.files("jaxspec") / "data/example_data/MOS1.arf",
60
- importlib.resources.files("jaxspec") / "data/example_data/MOS1.rmf",
58
+ str(importlib.resources.files("jaxspec") / "data/example_data/MOS1.rmf"),
59
+ str(importlib.resources.files("jaxspec") / "data/example_data/MOS1.arf"),
61
60
  ),
62
61
  "MOS2": Instrument.from_ogip_file(
63
- importlib.resources.files("jaxspec") / "data/example_data/MOS2.arf",
64
- importlib.resources.files("jaxspec") / "data/example_data/MOS2.rmf",
62
+ str(importlib.resources.files("jaxspec") / "data/example_data/MOS2.rmf"),
63
+ str(importlib.resources.files("jaxspec") / "data/example_data/MOS2.arf"),
65
64
  ),
66
65
  }
67
66
 
@@ -202,46 +201,51 @@ def fakeit_for_multiple_parameters(
202
201
  return fakeits[0] if len(fakeits) == 1 else fakeits
203
202
 
204
203
 
205
def data_path_finder(pha_path: str) -> Tuple[str | None, str | None, str | None]:
    """
    Locate the ARF, RMF and background files referenced by a PHA file.

    The ``ANCRFILE``, ``RESPFILE`` and ``BACKFILE`` keywords of the ``SPECTRUM``
    extension header are looked up in the directory that contains the PHA file.
    A keyword that is absent, empty, or equal to ``"none"`` (case-insensitive)
    gives ``None``.

    Parameters:
        pha_path: The PHA file path.

    Returns:
        arf_path: The ARF file path, or None.
        rmf_path: The RMF file path, or None.
        bkg_path: The BKG file path, or None.

    A referenced file that cannot be found makes ``find_file_or_compressed_in_dir``
    raise ``FileNotFoundError``, which is propagated.
    """
    spectrum_header = fits.getheader(pha_path, "SPECTRUM")
    parent_dir = str(Path(pha_path).parent)

    def resolve(keyword: str) -> str | None:
        # Referenced files are searched for next to the PHA file itself.
        name = spectrum_header.get(keyword, "none")
        if name.lower() == "none" or name == "":
            return None
        return find_file_or_compressed_in_dir(name, parent_dir)

    return resolve("ANCRFILE"), resolve("RESPFILE"), resolve("BACKFILE")
230
+
231
+
232
+ def find_file_or_compressed_in_dir(path: str | Path, directory: str | Path) -> str:
233
+ """
234
+ Try to find a file or its .gz compressed version in a given directory and return
235
+ the full path of the file.
236
+ """
237
+ path = Path(path) if isinstance(path, str) else path
238
+ directory = Path(directory) if isinstance(directory, str) else directory
239
+
240
+ if directory.joinpath(path).exists():
241
+ return str(directory.joinpath(path))
242
+
243
+ matching_files = list(directory.glob(str(path) + "*"))
244
+
245
+ if matching_files:
246
+ file = matching_files[0]
247
+ if file.suffix == ".gz":
248
+ return str(file)
246
249
 
247
- return pha, arf, rmf, bkg, metadata
250
+ else:
251
+ raise FileNotFoundError(f"Can't find {path}(.gz) in {directory}.")
jaxspec/fit.py CHANGED
@@ -7,6 +7,7 @@ from typing import Callable, TypeVar
7
7
  from abc import ABC
8
8
  from jax import random
9
9
  from jax.tree_util import tree_map
10
+ from jax.experimental.sparse import BCSR
10
11
  from .analysis.results import ChainResult
11
12
  from .model.abc import SpectralModel
12
13
  from .data import ObsConfiguration
@@ -45,16 +46,19 @@ def build_numpyro_model(
45
46
  model: SpectralModel,
46
47
  background_model: BackgroundModel,
47
48
  name: str = "",
49
+ sparse: bool = False,
48
50
  ) -> Callable:
49
51
  def numpro_model(prior_params, observed=True):
50
52
  # prior_params = build_prior(prior_distributions, name=name)
51
- transformed_model = hk.without_apply_rng(hk.transform(lambda par: CountForwardModel(model, obs)(par)))
53
+ transformed_model = hk.without_apply_rng(hk.transform(lambda par: CountForwardModel(model, obs, sparse=sparse)(par)))
52
54
 
53
55
  if (getattr(obs, "folded_background", None) is not None) and (background_model is not None):
54
- # TODO : Raise warning when setting a background model but there is no background spectra loaded
55
56
  bkg_countrate = background_model.numpyro_model(
56
57
  obs.out_energies, obs.folded_background.data, name=name + "bkg", observed=observed
57
58
  )
59
+ elif (getattr(obs, "folded_background", None) is None) and (background_model is not None):
60
+ raise ValueError("Trying to fit a background model but no background is linked to this observation")
61
+
58
62
  else:
59
63
  bkg_countrate = 0.0
60
64
 
@@ -65,7 +69,7 @@ def build_numpyro_model(
65
69
  with numpyro.plate(name + "obs_plate", len(obs.folded_counts)):
66
70
  numpyro.sample(
67
71
  name + "obs",
68
- Poisson(countrate + bkg_countrate * obs.backratio.data),
72
+ Poisson(countrate + bkg_countrate / obs.folded_backratio.data),
69
73
  obs=obs.folded_counts.data if observed else None,
70
74
  )
71
75
 
@@ -77,11 +81,16 @@ class CountForwardModel(hk.Module):
77
81
  A haiku module which allows to build the function that simulates the measured counts
78
82
  """
79
83
 
80
def __init__(self, model: SpectralModel, folding: ObsConfiguration, sparse=False):
    """
    Set up the forward-folding module.

    Parameters:
        model: The spectral model to evaluate.
        folding: The observation configuration; its ``in_energies`` and
            ``transfer_matrix`` are read here.
        sparse: If True, keep the transfer matrix as a JAX ``BCSR`` sparse matrix
            (converted through a SciPy CSR matrix); otherwise store it densified as
            a ``jnp`` array.
    """
    super().__init__()
    self.model = model
    self.energies = jnp.asarray(folding.in_energies)

    transfer = folding.transfer_matrix.data
    # NOTE(review): an earlier inline hint proposed
    # `folding.transfer_matrix.data.density > 0.015` as the criterion for sparsifying.
    # Sparse storage usually pays off for *low* density, so the direction of that
    # comparison should be confirmed before automating the choice.
    self.transfer_matrix = (
        BCSR.from_scipy_sparse(transfer.to_scipy_sparse().tocsr())
        if sparse
        else jnp.asarray(transfer.todense())
    )
85
94
 
86
95
  def __call__(self, parameters):
87
96
  """
@@ -119,7 +128,7 @@ class BayesianModelAbstract(ABC):
119
128
  num_samples: int = 1000,
120
129
  max_tree_depth: int = 10,
121
130
  target_accept_prob: float = 0.8,
122
- dense_mass=True,
131
+ dense_mass=False,
123
132
  mcmc_kwargs: dict = {},
124
133
  ) -> ChainResult:
125
134
  """
@@ -184,13 +193,14 @@ class BayesianModelAbstract(ABC):
184
193
 
185
194
  class BayesianModel(BayesianModelAbstract):
186
195
  """
187
- Class to fit a model to a given set of observation using a Bayesian approach.
196
+ Class to fit a model to a given observation using a Bayesian approach.
188
197
  """
189
198
 
190
def __init__(self, model, observation, background_model: BackgroundModel = None, sparsify_matrix: bool = False):
    """
    Parameters:
        model: The spectral model to fit (handed to the parent constructor and read
            back through ``self.model``).
        observation: The observation the model is fitted to.
        background_model: Optional model for the background spectrum.
        sparsify_matrix: If True, the transfer matrix is used in sparse form when
            the numpyro model is built (stored as ``self.sparse``).
    """
    super().__init__(model)
    self.observation = observation
    self.sparse = sparsify_matrix
    self.background_model = background_model
    # Cast every leaf of the model's parameter tree to double precision.
    self.pars = tree_map(jnp.float64, self.model.params)
195
205
 
196
206
  def numpyro_model(self, prior_distributions: HaikuDict[Distribution]) -> Callable:
@@ -204,7 +214,7 @@ class BayesianModel(BayesianModelAbstract):
204
214
 
205
215
  def model(observed=True):
206
216
  prior_params = build_prior(prior_distributions)
207
- obs_model = build_numpyro_model(self.observation, self.model, self.background_model)
217
+ obs_model = build_numpyro_model(self.observation, self.model, self.background_model, sparse=self.sparse)
208
218
  obs_model(prior_params, observed=observed)
209
219
 
210
220
  return model