jaxspec 0.0.2__py3-none-any.whl → 0.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxspec/analysis/results.py +1 -1
- jaxspec/data/__init__.py +3 -0
- jaxspec/data/example_data/fakeit.pha +335 -1
- jaxspec/data/instrument.py +14 -9
- jaxspec/data/obsconf.py +109 -51
- jaxspec/data/observation.py +45 -18
- jaxspec/data/ogip.py +100 -40
- jaxspec/data/util.py +51 -47
- jaxspec/fit.py +19 -9
- jaxspec/model/abc.py +29 -6
- jaxspec/model/additive.py +87 -22
- jaxspec/model/background.py +5 -5
- jaxspec/model/multiplicative.py +56 -15
- {jaxspec-0.0.2.dist-info → jaxspec-0.0.4.dist-info}/METADATA +8 -4
- {jaxspec-0.0.2.dist-info → jaxspec-0.0.4.dist-info}/RECORD +17 -16
- {jaxspec-0.0.2.dist-info → jaxspec-0.0.4.dist-info}/LICENSE.md +0 -0
- {jaxspec-0.0.2.dist-info → jaxspec-0.0.4.dist-info}/WHEEL +0 -0
jaxspec/data/observation.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
import os
|
|
2
1
|
import numpy as np
|
|
3
2
|
import xarray as xr
|
|
3
|
+
from .ogip import DataPHA
|
|
4
4
|
|
|
5
5
|
|
|
6
6
|
class Observation(xr.Dataset):
|
|
@@ -43,31 +43,40 @@ class Observation(xr.Dataset):
|
|
|
43
43
|
attributes = {}
|
|
44
44
|
|
|
45
45
|
if background is None:
|
|
46
|
-
background = np.zeros_like(counts, dtype=
|
|
46
|
+
background = np.zeros_like(counts, dtype=np.int64)
|
|
47
47
|
|
|
48
48
|
data_dict = {
|
|
49
|
-
"counts": (["instrument_channel"], np.
|
|
49
|
+
"counts": (["instrument_channel"], np.asarray(counts, dtype=np.int64), {"description": "Counts", "unit": "photons"}),
|
|
50
50
|
"folded_counts": (
|
|
51
51
|
["folded_channel"],
|
|
52
|
-
np.
|
|
52
|
+
np.asarray(np.ma.filled(grouping @ counts), dtype=np.int64),
|
|
53
53
|
{"description": "Folded counts, after grouping", "unit": "photons"},
|
|
54
54
|
),
|
|
55
55
|
"grouping": (
|
|
56
56
|
["folded_channel", "instrument_channel"],
|
|
57
|
-
|
|
57
|
+
grouping,
|
|
58
58
|
{"description": "Grouping matrix."},
|
|
59
59
|
),
|
|
60
|
-
"quality": (["instrument_channel"], np.
|
|
60
|
+
"quality": (["instrument_channel"], np.asarray(quality, dtype=np.int64), {"description": "Quality flag."}),
|
|
61
61
|
"exposure": ([], float(exposure), {"description": "Total exposure", "unit": "s"}),
|
|
62
|
-
"backratio": (
|
|
62
|
+
"backratio": (
|
|
63
|
+
["instrument_channel"],
|
|
64
|
+
np.asarray(backratio, dtype=float),
|
|
65
|
+
{"description": "Background scaling (SRC_BACKSCAL/BKG_BACKSCAL)"},
|
|
66
|
+
),
|
|
67
|
+
"folded_backratio": (
|
|
68
|
+
["folded_channel"],
|
|
69
|
+
np.asarray(np.ma.filled(grouping @ backratio), dtype=float),
|
|
70
|
+
{"description": "Background scaling after grouping"},
|
|
71
|
+
),
|
|
63
72
|
"background": (
|
|
64
73
|
["instrument_channel"],
|
|
65
|
-
np.
|
|
74
|
+
np.asarray(background, dtype=np.int64),
|
|
66
75
|
{"description": "Background counts", "unit": "photons"},
|
|
67
76
|
),
|
|
68
77
|
"folded_background": (
|
|
69
78
|
["folded_channel"],
|
|
70
|
-
np.
|
|
79
|
+
np.asarray(np.ma.filled(grouping @ background), dtype=np.int64),
|
|
71
80
|
{"description": "Background counts", "unit": "photons"},
|
|
72
81
|
),
|
|
73
82
|
}
|
|
@@ -75,7 +84,7 @@ class Observation(xr.Dataset):
|
|
|
75
84
|
return cls(
|
|
76
85
|
data_dict,
|
|
77
86
|
coords={
|
|
78
|
-
"channel": (["instrument_channel"], np.
|
|
87
|
+
"channel": (["instrument_channel"], np.asarray(channel, dtype=np.int64), {"description": "Channel number"}),
|
|
79
88
|
"grouped_channel": (
|
|
80
89
|
["folded_channel"],
|
|
81
90
|
np.arange(len(grouping @ counts), dtype=np.int64),
|
|
@@ -86,15 +95,11 @@ class Observation(xr.Dataset):
|
|
|
86
95
|
)
|
|
87
96
|
|
|
88
97
|
@classmethod
|
|
89
|
-
def
|
|
90
|
-
from .util import data_loader
|
|
91
|
-
|
|
92
|
-
pha, arf, rmf, bkg, metadata = data_loader(pha_file)
|
|
93
|
-
|
|
98
|
+
def from_ogip_container(cls, pha: DataPHA, bkg: DataPHA | None = None, **metadata):
|
|
94
99
|
if bkg is not None:
|
|
95
|
-
backratio = (pha.backscal * pha.exposure * pha.areascal) / (bkg.backscal * bkg.exposure * bkg.areascal)
|
|
100
|
+
backratio = np.nan_to_num((pha.backscal * pha.exposure * pha.areascal) / (bkg.backscal * bkg.exposure * bkg.areascal))
|
|
96
101
|
else:
|
|
97
|
-
backratio =
|
|
102
|
+
backratio = np.ones_like(pha.counts)
|
|
98
103
|
|
|
99
104
|
return cls.from_matrix(
|
|
100
105
|
pha.counts,
|
|
@@ -107,6 +112,28 @@ class Observation(xr.Dataset):
|
|
|
107
112
|
attributes=metadata,
|
|
108
113
|
)
|
|
109
114
|
|
|
115
|
+
@classmethod
|
|
116
|
+
def from_pha_file(cls, pha_path: str, bkg_path: str | None = None, **metadata):
|
|
117
|
+
from .util import data_path_finder
|
|
118
|
+
|
|
119
|
+
arf_path, rmf_path, bkg_path_default = data_path_finder(pha_path)
|
|
120
|
+
bkg_path = bkg_path_default if bkg_path is None else bkg_path
|
|
121
|
+
|
|
122
|
+
pha = DataPHA.from_file(pha_path)
|
|
123
|
+
bkg = DataPHA.from_file(bkg_path) if bkg_path is not None else None
|
|
124
|
+
|
|
125
|
+
if metadata is None:
|
|
126
|
+
metadata = {}
|
|
127
|
+
|
|
128
|
+
metadata.update(
|
|
129
|
+
observation_file=pha_path,
|
|
130
|
+
background_file=bkg_path,
|
|
131
|
+
response_matrix_file=rmf_path,
|
|
132
|
+
ancillary_response_file=arf_path,
|
|
133
|
+
)
|
|
134
|
+
|
|
135
|
+
return cls.from_ogip_container(pha, bkg=bkg, **metadata)
|
|
136
|
+
|
|
110
137
|
def plot_counts(self, **kwargs):
|
|
111
138
|
"""
|
|
112
139
|
Plot the counts
|
|
@@ -133,7 +160,7 @@ class Observation(xr.Dataset):
|
|
|
133
160
|
ax = fig.add_subplot(gs[1, 0])
|
|
134
161
|
ax_histx = fig.add_subplot(gs[0, 0], sharex=ax)
|
|
135
162
|
ax_histy = fig.add_subplot(gs[1, 1], sharey=ax)
|
|
136
|
-
sns.heatmap(self.grouping.T, ax=ax, cbar=False)
|
|
163
|
+
sns.heatmap(self.grouping.data.todense().T, ax=ax, cbar=False)
|
|
137
164
|
ax_histx.step(np.arange(len(self.folded_counts)), self.folded_counts, where="post")
|
|
138
165
|
ax_histy.step(self.counts, np.arange(len(self.counts)), where="post")
|
|
139
166
|
|
jaxspec/data/ogip.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import numpy as np
|
|
2
2
|
import os
|
|
3
3
|
import astropy.units as u
|
|
4
|
+
import sparse
|
|
4
5
|
from astropy.table import QTable
|
|
5
6
|
from astropy.io import fits
|
|
6
7
|
|
|
@@ -8,7 +9,6 @@ from astropy.io import fits
|
|
|
8
9
|
class DataPHA:
|
|
9
10
|
r"""
|
|
10
11
|
Class to handle PHA data defined with OGIP standards.
|
|
11
|
-
|
|
12
12
|
??? info "References"
|
|
13
13
|
* [The OGIP standard PHA file format](https://heasarc.gsfc.nasa.gov/docs/heasarc/ofwg/docs/spectra/ogip_92_007/node5.html)
|
|
14
14
|
"""
|
|
@@ -26,30 +26,40 @@ class DataPHA:
|
|
|
26
26
|
backscal=1.0,
|
|
27
27
|
areascal=1.0,
|
|
28
28
|
):
|
|
29
|
-
self.channel = channel
|
|
30
|
-
self.counts = counts
|
|
31
|
-
self.exposure = exposure
|
|
29
|
+
self.channel = np.asarray(channel, dtype=int)
|
|
30
|
+
self.counts = np.asarray(counts, dtype=int)
|
|
31
|
+
self.exposure = float(exposure)
|
|
32
32
|
|
|
33
|
-
self.quality = quality
|
|
33
|
+
self.quality = np.asarray(quality, dtype=int)
|
|
34
34
|
self.backfile = backfile
|
|
35
35
|
self.respfile = respfile
|
|
36
36
|
self.ancrfile = ancrfile
|
|
37
|
-
self.backscal = backscal
|
|
38
|
-
self.areascal = areascal
|
|
37
|
+
self.backscal = np.asarray(backscal, dtype=float)
|
|
38
|
+
self.areascal = np.asarray(areascal, dtype=float)
|
|
39
39
|
|
|
40
40
|
if grouping is not None:
|
|
41
|
-
# Indices array of beginning of each group
|
|
41
|
+
# Indices array of the beginning of each group
|
|
42
42
|
b_grp = np.where(grouping == 1)[0]
|
|
43
|
-
# Indices array of ending of each group
|
|
43
|
+
# Indices array of the ending of each group
|
|
44
44
|
e_grp = np.hstack((b_grp[1:], len(channel)))
|
|
45
|
-
|
|
46
|
-
|
|
45
|
+
|
|
46
|
+
# Prepare data for sparse matrix
|
|
47
|
+
rows = []
|
|
48
|
+
cols = []
|
|
49
|
+
data = []
|
|
47
50
|
|
|
48
51
|
for i in range(len(b_grp)):
|
|
49
|
-
|
|
52
|
+
for j in range(b_grp[i], e_grp[i]):
|
|
53
|
+
rows.append(i)
|
|
54
|
+
cols.append(j)
|
|
55
|
+
data.append(True)
|
|
56
|
+
|
|
57
|
+
# Create a COO sparse matrix
|
|
58
|
+
grp_matrix = sparse.COO((data, (rows, cols)), shape=(len(b_grp), len(channel)), fill_value=0)
|
|
50
59
|
|
|
51
60
|
else:
|
|
52
|
-
|
|
61
|
+
# Identity matrix case, use sparse for efficiency
|
|
62
|
+
grp_matrix = sparse.eye(len(channel), format="coo", dtype=bool)
|
|
53
63
|
|
|
54
64
|
self.grouping = grp_matrix
|
|
55
65
|
|
|
@@ -65,24 +75,44 @@ class DataPHA:
|
|
|
65
75
|
data = QTable.read(pha_file, "SPECTRUM")
|
|
66
76
|
header = fits.getheader(pha_file, "SPECTRUM")
|
|
67
77
|
|
|
68
|
-
if "
|
|
78
|
+
if header.get("GROUPING") == 0:
|
|
79
|
+
grouping = None
|
|
80
|
+
elif "GROUPING" in data.colnames:
|
|
81
|
+
grouping = data["GROUPING"]
|
|
82
|
+
else:
|
|
83
|
+
raise ValueError("No grouping column found in the PHA file.")
|
|
84
|
+
|
|
85
|
+
if header.get("QUALITY") == 0:
|
|
86
|
+
quality = np.zeros(len(data["CHANNEL"]), dtype=bool)
|
|
87
|
+
elif "QUALITY" in data.colnames:
|
|
69
88
|
quality = data["QUALITY"]
|
|
70
89
|
else:
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
90
|
+
raise ValueError("No QUALITY column found in the PHA file.")
|
|
91
|
+
|
|
92
|
+
if "BACKSCAL" in header:
|
|
93
|
+
backscal = header["BACKSCAL"] * np.ones_like(data["CHANNEL"])
|
|
94
|
+
elif "BACKSCAL" in data.colnames:
|
|
95
|
+
backscal = data["BACKSCAL"]
|
|
96
|
+
else:
|
|
97
|
+
raise ValueError("No BACKSCAL found in the PHA file.")
|
|
98
|
+
|
|
99
|
+
if "AREASCAL" in header:
|
|
100
|
+
areascal = header["AREASCAL"]
|
|
101
|
+
elif "AREASCAL" in data.colnames:
|
|
102
|
+
areascal = data["AREASCAL"]
|
|
103
|
+
else:
|
|
104
|
+
raise ValueError("No AREASCAL found in the PHA file.")
|
|
75
105
|
|
|
76
106
|
# Grouping and quality parameters are in binned PHA dataset
|
|
77
107
|
# Backfile, respfile and ancrfile are in primary header
|
|
78
108
|
kwargs = {
|
|
79
|
-
"grouping":
|
|
109
|
+
"grouping": grouping,
|
|
80
110
|
"quality": quality,
|
|
81
111
|
"backfile": header.get("BACKFILE"),
|
|
82
112
|
"respfile": header.get("RESPFILE"),
|
|
83
113
|
"ancrfile": header.get("ANCRFILE"),
|
|
84
|
-
"backscal":
|
|
85
|
-
"areascal":
|
|
114
|
+
"backscal": backscal,
|
|
115
|
+
"areascal": areascal,
|
|
86
116
|
}
|
|
87
117
|
|
|
88
118
|
return cls(data["CHANNEL"], data["COUNTS"], header["EXPOSURE"], **kwargs)
|
|
@@ -123,18 +153,16 @@ class DataARF:
|
|
|
123
153
|
class DataRMF:
|
|
124
154
|
r"""
|
|
125
155
|
Class to handle RMF data defined with OGIP standards.
|
|
126
|
-
|
|
127
156
|
??? info "References"
|
|
128
157
|
* [The Calibration Requirements for Spectral Analysis (Definition of RMF and ARF file formats)](https://heasarc.gsfc.nasa.gov/docs/heasarc/caldb/docs/memos/cal_gen_92_002/cal_gen_92_002.html)
|
|
129
158
|
* [The Calibration Requirements for Spectral Analysis Addendum: Changes log](https://heasarc.gsfc.nasa.gov/docs/heasarc/caldb/docs/memos/cal_gen_92_002a/cal_gen_92_002a.html)
|
|
130
|
-
|
|
131
159
|
"""
|
|
132
160
|
|
|
133
|
-
def __init__(self, energ_lo, energ_hi, n_grp, f_chan, n_chan, matrix, channel, e_min, e_max):
|
|
161
|
+
def __init__(self, energ_lo, energ_hi, n_grp, f_chan, n_chan, matrix, channel, e_min, e_max, low_threshold=0.0):
|
|
134
162
|
# RMF stuff
|
|
135
163
|
self.energ_lo = energ_lo # "Entry" energies
|
|
136
164
|
self.energ_hi = energ_hi # "Entry" energies
|
|
137
|
-
self.n_grp = n_grp
|
|
165
|
+
self.n_grp = n_grp
|
|
138
166
|
self.f_chan = f_chan
|
|
139
167
|
self.n_chan = n_chan
|
|
140
168
|
self.matrix_entry = matrix
|
|
@@ -144,35 +172,51 @@ class DataRMF:
|
|
|
144
172
|
self.e_min = e_min
|
|
145
173
|
self.e_max = e_max
|
|
146
174
|
|
|
147
|
-
|
|
175
|
+
# Prepare data for sparse matrix
|
|
176
|
+
rows = []
|
|
177
|
+
cols = []
|
|
178
|
+
data = []
|
|
148
179
|
|
|
149
|
-
for i,
|
|
180
|
+
for i, n_grp_val in enumerate(self.n_grp):
|
|
150
181
|
base = 0
|
|
151
182
|
|
|
152
183
|
if np.size(self.f_chan[i]) == 1:
|
|
153
|
-
# ravel()[0] allows to get the value of the array without triggering numpy's conversion from
|
|
154
|
-
# multidimensional array to scalar
|
|
155
184
|
low = int(self.f_chan[i].ravel()[0])
|
|
156
185
|
high = min(
|
|
157
186
|
int(self.f_chan[i].ravel()[0] + self.n_chan[i].ravel()[0]),
|
|
158
|
-
self.
|
|
187
|
+
len(self.channel),
|
|
159
188
|
)
|
|
160
189
|
|
|
161
|
-
|
|
190
|
+
rows.extend([i] * (high - low))
|
|
191
|
+
cols.extend(range(low, high))
|
|
192
|
+
data.extend(self.matrix_entry[i][0 : high - low])
|
|
162
193
|
|
|
163
194
|
else:
|
|
164
|
-
for j in range(
|
|
195
|
+
for j in range(n_grp_val):
|
|
165
196
|
low = self.f_chan[i][j]
|
|
166
|
-
high = min(self.f_chan[i][j] + self.n_chan[i][j], self.
|
|
197
|
+
high = min(self.f_chan[i][j] + self.n_chan[i][j], len(self.channel))
|
|
167
198
|
|
|
168
|
-
|
|
199
|
+
rows.extend([i] * (high - low))
|
|
200
|
+
cols.extend(range(low, high))
|
|
201
|
+
data.extend(self.matrix_entry[i][base : base + self.n_chan[i][j]])
|
|
169
202
|
|
|
170
203
|
base += self.n_chan[i][j]
|
|
171
204
|
|
|
172
|
-
#
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
205
|
+
# Convert lists to numpy arrays
|
|
206
|
+
rows = np.array(rows)
|
|
207
|
+
cols = np.array(cols)
|
|
208
|
+
data = np.array(data)
|
|
209
|
+
|
|
210
|
+
# Sometimes, zero elements are given in the matrix rows, so we get rid of them
|
|
211
|
+
idxs = data > low_threshold
|
|
212
|
+
|
|
213
|
+
# Create a COO sparse matrix and then convert to CSR for efficiency
|
|
214
|
+
coo = sparse.COO([rows[idxs], cols[idxs]], data[idxs], shape=(len(self.energ_lo), len(self.channel)))
|
|
215
|
+
self.sparse_matrix = coo.T # .tocsr()
|
|
216
|
+
|
|
217
|
+
@property
|
|
218
|
+
def matrix(self):
|
|
219
|
+
return np.asarray(self.sparse_matrix.todense())
|
|
176
220
|
|
|
177
221
|
@classmethod
|
|
178
222
|
def from_file(cls, rmf_file: str | os.PathLike):
|
|
@@ -182,15 +226,31 @@ class DataRMF:
|
|
|
182
226
|
Parameters:
|
|
183
227
|
rmf_file: The RMF file path.
|
|
184
228
|
"""
|
|
229
|
+
extension_names = [hdu[1] for hdu in fits.info(rmf_file, output=False)]
|
|
230
|
+
|
|
231
|
+
if "MATRIX" in extension_names:
|
|
232
|
+
matrix_extension = "MATRIX"
|
|
185
233
|
|
|
186
|
-
|
|
234
|
+
elif "SPECRESP MATRIX" in extension_names:
|
|
235
|
+
matrix_extension = "SPECRESP MATRIX"
|
|
236
|
+
# raise NotImplementedError("SPECRESP MATRIX extension is not yet supported")
|
|
237
|
+
|
|
238
|
+
else:
|
|
239
|
+
raise ValueError("No MATRIX or SPECRESP MATRIX extension found in the RMF file")
|
|
240
|
+
|
|
241
|
+
matrix_table = QTable.read(rmf_file, matrix_extension)
|
|
187
242
|
ebounds_table = QTable.read(rmf_file, "EBOUNDS")
|
|
188
243
|
|
|
244
|
+
matrix_header = fits.getheader(rmf_file, matrix_extension)
|
|
245
|
+
|
|
246
|
+
f_chan_column_pos = list(matrix_table.columns).index("F_CHAN") + 1
|
|
247
|
+
tlmin_fchan = int(matrix_header[f"TLMIN{f_chan_column_pos}"])
|
|
248
|
+
|
|
189
249
|
return cls(
|
|
190
250
|
matrix_table["ENERG_LO"],
|
|
191
251
|
matrix_table["ENERG_HI"],
|
|
192
252
|
matrix_table["N_GRP"],
|
|
193
|
-
matrix_table["F_CHAN"],
|
|
253
|
+
matrix_table["F_CHAN"] - tlmin_fchan,
|
|
194
254
|
matrix_table["N_CHAN"],
|
|
195
255
|
matrix_table["MATRIX"],
|
|
196
256
|
ebounds_table["CHANNEL"],
|
jaxspec/data/util.py
CHANGED
|
@@ -1,15 +1,14 @@
|
|
|
1
1
|
import importlib.resources
|
|
2
|
-
import os
|
|
3
|
-
import warnings
|
|
4
2
|
import numpyro
|
|
5
3
|
import jax
|
|
6
4
|
import numpy as np
|
|
7
5
|
import haiku as hk
|
|
6
|
+
from pathlib import Path
|
|
8
7
|
from numpy.typing import ArrayLike
|
|
9
8
|
from collections.abc import Mapping
|
|
10
|
-
from typing import TypeVar
|
|
9
|
+
from typing import TypeVar, Tuple
|
|
10
|
+
from astropy.io import fits
|
|
11
11
|
|
|
12
|
-
from .ogip import DataPHA, DataARF, DataRMF
|
|
13
12
|
from . import Observation, Instrument, ObsConfiguration
|
|
14
13
|
from ..model.abc import SpectralModel
|
|
15
14
|
from ..fit import CountForwardModel
|
|
@@ -26,17 +25,17 @@ def load_example_observations():
|
|
|
26
25
|
|
|
27
26
|
example_observations = {
|
|
28
27
|
"PN": Observation.from_pha_file(
|
|
29
|
-
importlib.resources.files("jaxspec") / "data/example_data/PN_spectrum_grp20.fits",
|
|
28
|
+
str(importlib.resources.files("jaxspec") / "data/example_data/PN_spectrum_grp20.fits"),
|
|
30
29
|
low_energy=0.3,
|
|
31
30
|
high_energy=7.5,
|
|
32
31
|
),
|
|
33
32
|
"MOS1": Observation.from_pha_file(
|
|
34
|
-
importlib.resources.files("jaxspec") / "data/example_data/MOS1_spectrum_grp.fits",
|
|
33
|
+
str(importlib.resources.files("jaxspec") / "data/example_data/MOS1_spectrum_grp.fits"),
|
|
35
34
|
low_energy=0.3,
|
|
36
35
|
high_energy=7,
|
|
37
36
|
),
|
|
38
37
|
"MOS2": Observation.from_pha_file(
|
|
39
|
-
importlib.resources.files("jaxspec") / "data/example_data/MOS2_spectrum_grp.fits",
|
|
38
|
+
str(importlib.resources.files("jaxspec") / "data/example_data/MOS2_spectrum_grp.fits"),
|
|
40
39
|
low_energy=0.3,
|
|
41
40
|
high_energy=7,
|
|
42
41
|
),
|
|
@@ -52,16 +51,16 @@ def load_example_instruments():
|
|
|
52
51
|
|
|
53
52
|
example_instruments = {
|
|
54
53
|
"PN": Instrument.from_ogip_file(
|
|
55
|
-
importlib.resources.files("jaxspec") / "data/example_data/PN.
|
|
56
|
-
importlib.resources.files("jaxspec") / "data/example_data/PN.
|
|
54
|
+
str(importlib.resources.files("jaxspec") / "data/example_data/PN.rmf"),
|
|
55
|
+
str(importlib.resources.files("jaxspec") / "data/example_data/PN.arf"),
|
|
57
56
|
),
|
|
58
57
|
"MOS1": Instrument.from_ogip_file(
|
|
59
|
-
importlib.resources.files("jaxspec") / "data/example_data/MOS1.
|
|
60
|
-
importlib.resources.files("jaxspec") / "data/example_data/MOS1.
|
|
58
|
+
str(importlib.resources.files("jaxspec") / "data/example_data/MOS1.rmf"),
|
|
59
|
+
str(importlib.resources.files("jaxspec") / "data/example_data/MOS1.arf"),
|
|
61
60
|
),
|
|
62
61
|
"MOS2": Instrument.from_ogip_file(
|
|
63
|
-
importlib.resources.files("jaxspec") / "data/example_data/MOS2.
|
|
64
|
-
importlib.resources.files("jaxspec") / "data/example_data/MOS2.
|
|
62
|
+
str(importlib.resources.files("jaxspec") / "data/example_data/MOS2.rmf"),
|
|
63
|
+
str(importlib.resources.files("jaxspec") / "data/example_data/MOS2.arf"),
|
|
65
64
|
),
|
|
66
65
|
}
|
|
67
66
|
|
|
@@ -202,46 +201,51 @@ def fakeit_for_multiple_parameters(
|
|
|
202
201
|
return fakeits[0] if len(fakeits) == 1 else fakeits
|
|
203
202
|
|
|
204
203
|
|
|
205
|
-
def
|
|
204
|
+
def data_path_finder(pha_path: str) -> Tuple[str | None, str | None, str | None]:
|
|
206
205
|
"""
|
|
207
|
-
This function
|
|
208
|
-
from a given PHA file, using either the ARF/RMF/BKG filenames in the header or the
|
|
209
|
-
specified filenames overwritten by the user.
|
|
210
|
-
|
|
206
|
+
This function tries its best to find the ARF, RMF and BKG files associated with a given PHA file.
|
|
211
207
|
Parameters:
|
|
212
208
|
pha_path: The PHA file path.
|
|
209
|
+
|
|
210
|
+
Returns:
|
|
213
211
|
arf_path: The ARF file path.
|
|
214
212
|
rmf_path: The RMF file path.
|
|
215
213
|
bkg_path: The BKG file path.
|
|
216
214
|
"""
|
|
217
215
|
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
216
|
+
def find_path(file_name: str, directory: str) -> str | None:
|
|
217
|
+
if file_name.lower() != "none" and file_name != "":
|
|
218
|
+
return find_file_or_compressed_in_dir(file_name, directory)
|
|
219
|
+
else:
|
|
220
|
+
return None
|
|
221
|
+
|
|
222
|
+
header = fits.getheader(pha_path, "SPECTRUM")
|
|
223
|
+
directory = str(Path(pha_path).parent)
|
|
224
|
+
|
|
225
|
+
arf_path = find_path(header.get("ANCRFILE", "none"), directory)
|
|
226
|
+
rmf_path = find_path(header.get("RESPFILE", "none"), directory)
|
|
227
|
+
bkg_path = find_path(header.get("BACKFILE", "none"), directory)
|
|
228
|
+
|
|
229
|
+
return arf_path, rmf_path, bkg_path
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
def find_file_or_compressed_in_dir(path: str | Path, directory: str | Path) -> str:
|
|
233
|
+
"""
|
|
234
|
+
Try to find a file or its .gz compressed version in a given directory and return
|
|
235
|
+
the full path of the file.
|
|
236
|
+
"""
|
|
237
|
+
path = Path(path) if isinstance(path, str) else path
|
|
238
|
+
directory = Path(directory) if isinstance(directory, str) else directory
|
|
239
|
+
|
|
240
|
+
if directory.joinpath(path).exists():
|
|
241
|
+
return str(directory.joinpath(path))
|
|
242
|
+
|
|
243
|
+
matching_files = list(directory.glob(str(path) + "*"))
|
|
244
|
+
|
|
245
|
+
if matching_files:
|
|
246
|
+
file = matching_files[0]
|
|
247
|
+
if file.suffix == ".gz":
|
|
248
|
+
return str(file)
|
|
246
249
|
|
|
247
|
-
|
|
250
|
+
else:
|
|
251
|
+
raise FileNotFoundError(f"Can't find {path}(.gz) in {directory}.")
|
jaxspec/fit.py
CHANGED
|
@@ -7,6 +7,7 @@ from typing import Callable, TypeVar
|
|
|
7
7
|
from abc import ABC
|
|
8
8
|
from jax import random
|
|
9
9
|
from jax.tree_util import tree_map
|
|
10
|
+
from jax.experimental.sparse import BCSR
|
|
10
11
|
from .analysis.results import ChainResult
|
|
11
12
|
from .model.abc import SpectralModel
|
|
12
13
|
from .data import ObsConfiguration
|
|
@@ -45,16 +46,19 @@ def build_numpyro_model(
|
|
|
45
46
|
model: SpectralModel,
|
|
46
47
|
background_model: BackgroundModel,
|
|
47
48
|
name: str = "",
|
|
49
|
+
sparse: bool = False,
|
|
48
50
|
) -> Callable:
|
|
49
51
|
def numpro_model(prior_params, observed=True):
|
|
50
52
|
# prior_params = build_prior(prior_distributions, name=name)
|
|
51
|
-
transformed_model = hk.without_apply_rng(hk.transform(lambda par: CountForwardModel(model, obs)(par)))
|
|
53
|
+
transformed_model = hk.without_apply_rng(hk.transform(lambda par: CountForwardModel(model, obs, sparse=sparse)(par)))
|
|
52
54
|
|
|
53
55
|
if (getattr(obs, "folded_background", None) is not None) and (background_model is not None):
|
|
54
|
-
# TODO : Raise warning when setting a background model but there is no background spectra loaded
|
|
55
56
|
bkg_countrate = background_model.numpyro_model(
|
|
56
57
|
obs.out_energies, obs.folded_background.data, name=name + "bkg", observed=observed
|
|
57
58
|
)
|
|
59
|
+
elif (getattr(obs, "folded_background", None) is None) and (background_model is not None):
|
|
60
|
+
raise ValueError("Trying to fit a background model but no background is linked to this observation")
|
|
61
|
+
|
|
58
62
|
else:
|
|
59
63
|
bkg_countrate = 0.0
|
|
60
64
|
|
|
@@ -65,7 +69,7 @@ def build_numpyro_model(
|
|
|
65
69
|
with numpyro.plate(name + "obs_plate", len(obs.folded_counts)):
|
|
66
70
|
numpyro.sample(
|
|
67
71
|
name + "obs",
|
|
68
|
-
Poisson(countrate + bkg_countrate
|
|
72
|
+
Poisson(countrate + bkg_countrate / obs.folded_backratio.data),
|
|
69
73
|
obs=obs.folded_counts.data if observed else None,
|
|
70
74
|
)
|
|
71
75
|
|
|
@@ -77,11 +81,16 @@ class CountForwardModel(hk.Module):
|
|
|
77
81
|
A haiku module which allows to build the function that simulates the measured counts
|
|
78
82
|
"""
|
|
79
83
|
|
|
80
|
-
def __init__(self, model: SpectralModel, folding: ObsConfiguration):
|
|
84
|
+
def __init__(self, model: SpectralModel, folding: ObsConfiguration, sparse=False):
|
|
81
85
|
super().__init__()
|
|
82
86
|
self.model = model
|
|
83
87
|
self.energies = jnp.asarray(folding.in_energies)
|
|
84
|
-
|
|
88
|
+
|
|
89
|
+
if sparse: # folding.transfer_matrix.data.density > 0.015 is a good criterion to consider sparsify
|
|
90
|
+
self.transfer_matrix = BCSR.from_scipy_sparse(folding.transfer_matrix.data.to_scipy_sparse().tocsr()) #
|
|
91
|
+
|
|
92
|
+
else:
|
|
93
|
+
self.transfer_matrix = jnp.asarray(folding.transfer_matrix.data.todense())
|
|
85
94
|
|
|
86
95
|
def __call__(self, parameters):
|
|
87
96
|
"""
|
|
@@ -119,7 +128,7 @@ class BayesianModelAbstract(ABC):
|
|
|
119
128
|
num_samples: int = 1000,
|
|
120
129
|
max_tree_depth: int = 10,
|
|
121
130
|
target_accept_prob: float = 0.8,
|
|
122
|
-
dense_mass=
|
|
131
|
+
dense_mass=False,
|
|
123
132
|
mcmc_kwargs: dict = {},
|
|
124
133
|
) -> ChainResult:
|
|
125
134
|
"""
|
|
@@ -184,13 +193,14 @@ class BayesianModelAbstract(ABC):
|
|
|
184
193
|
|
|
185
194
|
class BayesianModel(BayesianModelAbstract):
|
|
186
195
|
"""
|
|
187
|
-
Class to fit a model to a given
|
|
196
|
+
Class to fit a model to a given observation using a Bayesian approach.
|
|
188
197
|
"""
|
|
189
198
|
|
|
190
|
-
def __init__(self, model, observation, background_model: BackgroundModel = None):
|
|
199
|
+
def __init__(self, model, observation, background_model: BackgroundModel = None, sparsify_matrix: bool = False):
|
|
191
200
|
super().__init__(model)
|
|
192
201
|
self.observation = observation
|
|
193
202
|
self.pars = tree_map(lambda x: jnp.float64(x), self.model.params)
|
|
203
|
+
self.sparse = sparsify_matrix
|
|
194
204
|
self.background_model = background_model
|
|
195
205
|
|
|
196
206
|
def numpyro_model(self, prior_distributions: HaikuDict[Distribution]) -> Callable:
|
|
@@ -204,7 +214,7 @@ class BayesianModel(BayesianModelAbstract):
|
|
|
204
214
|
|
|
205
215
|
def model(observed=True):
|
|
206
216
|
prior_params = build_prior(prior_distributions)
|
|
207
|
-
obs_model = build_numpyro_model(self.observation, self.model, self.background_model)
|
|
217
|
+
obs_model = build_numpyro_model(self.observation, self.model, self.background_model, sparse=self.sparse)
|
|
208
218
|
obs_model(prior_params, observed=observed)
|
|
209
219
|
|
|
210
220
|
return model
|