jaxspec 0.1.3__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxspec/_fit/__init__.py +0 -0
- jaxspec/_fit/_build_model.py +63 -0
- jaxspec/analysis/_plot.py +166 -7
- jaxspec/analysis/results.py +238 -336
- jaxspec/data/instrument.py +47 -12
- jaxspec/data/obsconf.py +12 -2
- jaxspec/data/observation.py +68 -11
- jaxspec/data/ogip.py +32 -13
- jaxspec/data/util.py +5 -75
- jaxspec/fit.py +101 -140
- jaxspec/model/_graph_util.py +151 -0
- jaxspec/model/abc.py +275 -414
- jaxspec/model/additive.py +276 -289
- jaxspec/model/background.py +94 -87
- jaxspec/model/multiplicative.py +101 -85
- jaxspec/scripts/debug.py +1 -1
- jaxspec/util/__init__.py +0 -45
- jaxspec/util/misc.py +25 -0
- jaxspec/util/typing.py +0 -63
- {jaxspec-0.1.3.dist-info → jaxspec-0.2.0.dist-info}/METADATA +36 -16
- jaxspec-0.2.0.dist-info/RECORD +34 -0
- {jaxspec-0.1.3.dist-info → jaxspec-0.2.0.dist-info}/WHEEL +1 -1
- jaxspec/data/grouping.py +0 -23
- jaxspec-0.1.3.dist-info/RECORD +0 -31
- {jaxspec-0.1.3.dist-info → jaxspec-0.2.0.dist-info}/LICENSE.md +0 -0
- {jaxspec-0.1.3.dist-info → jaxspec-0.2.0.dist-info}/entry_points.txt +0 -0
jaxspec/data/instrument.py
CHANGED
|
@@ -1,7 +1,10 @@
|
|
|
1
1
|
import os
|
|
2
|
+
|
|
2
3
|
import numpy as np
|
|
3
4
|
import xarray as xr
|
|
5
|
+
|
|
4
6
|
from matplotlib import colors
|
|
7
|
+
|
|
5
8
|
from .ogip import DataARF, DataRMF
|
|
6
9
|
|
|
7
10
|
|
|
@@ -25,7 +28,15 @@ class Instrument(xr.Dataset):
|
|
|
25
28
|
)
|
|
26
29
|
|
|
27
30
|
@classmethod
|
|
28
|
-
def from_matrix(
|
|
31
|
+
def from_matrix(
|
|
32
|
+
cls,
|
|
33
|
+
redistribution_matrix,
|
|
34
|
+
spectral_response,
|
|
35
|
+
e_min_unfolded,
|
|
36
|
+
e_max_unfolded,
|
|
37
|
+
e_min_channel,
|
|
38
|
+
e_max_channel,
|
|
39
|
+
):
|
|
29
40
|
return cls(
|
|
30
41
|
{
|
|
31
42
|
"redistribution": (
|
|
@@ -65,7 +76,7 @@ class Instrument(xr.Dataset):
|
|
|
65
76
|
)
|
|
66
77
|
|
|
67
78
|
@classmethod
|
|
68
|
-
def from_ogip_file(cls, rmf_path: str | os.PathLike, arf_path: str | os.PathLike = None):
|
|
79
|
+
def from_ogip_file(cls, rmf_path: str | os.PathLike, arf_path: str | os.PathLike | None = None):
|
|
69
80
|
"""
|
|
70
81
|
Load the data from OGIP files.
|
|
71
82
|
|
|
@@ -80,37 +91,61 @@ class Instrument(xr.Dataset):
|
|
|
80
91
|
specresp = DataARF.from_file(arf_path).specresp
|
|
81
92
|
|
|
82
93
|
else:
|
|
83
|
-
specresp =
|
|
94
|
+
specresp = rmf.matrix.sum(axis=0)
|
|
95
|
+
rmf.sparse_matrix /= specresp
|
|
84
96
|
|
|
85
|
-
return cls.from_matrix(
|
|
97
|
+
return cls.from_matrix(
|
|
98
|
+
rmf.sparse_matrix, specresp, rmf.energ_lo, rmf.energ_hi, rmf.e_min, rmf.e_max
|
|
99
|
+
)
|
|
86
100
|
|
|
87
|
-
def plot_redistribution(
|
|
101
|
+
def plot_redistribution(
|
|
102
|
+
self,
|
|
103
|
+
xscale: str = "log",
|
|
104
|
+
yscale: str = "log",
|
|
105
|
+
cmap=None,
|
|
106
|
+
vmin: float = 1e-6,
|
|
107
|
+
vmax: float = 1e0,
|
|
108
|
+
add_labels: bool = True,
|
|
109
|
+
**kwargs,
|
|
110
|
+
):
|
|
88
111
|
"""
|
|
89
112
|
Plot the redistribution probability matrix
|
|
90
113
|
|
|
91
114
|
Parameters:
|
|
115
|
+
xscale : The scale of the x-axis.
|
|
116
|
+
yscale : The scale of the y-axis.
|
|
117
|
+
cmap : The colormap to use.
|
|
118
|
+
vmin : The minimum value for the colormap.
|
|
119
|
+
vmax : The maximum value for the colormap.
|
|
120
|
+
add_labels : Whether to add labels to the plot.
|
|
92
121
|
**kwargs : `kwargs` passed to https://docs.xarray.dev/en/latest/generated/xarray.plot.pcolormesh.html#xarray.plot.pcolormesh
|
|
93
122
|
"""
|
|
123
|
+
|
|
94
124
|
import cmasher as cmr
|
|
95
125
|
|
|
96
126
|
return xr.plot.pcolormesh(
|
|
97
127
|
self.redistribution,
|
|
98
128
|
x="e_max_unfolded",
|
|
99
129
|
y="e_max_channel",
|
|
100
|
-
xscale=
|
|
101
|
-
yscale=
|
|
102
|
-
cmap=cmr.ember_r,
|
|
103
|
-
norm=colors.LogNorm(vmin=
|
|
104
|
-
add_labels=
|
|
130
|
+
xscale=xscale,
|
|
131
|
+
yscale=yscale,
|
|
132
|
+
cmap=cmr.ember_r if cmap is None else cmap,
|
|
133
|
+
norm=colors.LogNorm(vmin=vmin, vmax=vmax),
|
|
134
|
+
add_labels=add_labels,
|
|
105
135
|
**kwargs,
|
|
106
136
|
)
|
|
107
137
|
|
|
108
|
-
def plot_area(self, **kwargs):
|
|
138
|
+
def plot_area(self, xscale: str = "log", yscale: str = "log", where: str = "post", **kwargs):
|
|
109
139
|
"""
|
|
110
140
|
Plot the effective area
|
|
111
141
|
|
|
112
142
|
Parameters:
|
|
143
|
+
xscale : The scale of the x-axis.
|
|
144
|
+
yscale : The scale of the y-axis.
|
|
145
|
+
where : The position of the steps.
|
|
113
146
|
**kwargs : `kwargs` passed to https://docs.xarray.dev/en/latest/generated/xarray.DataArray.plot.line.html#xarray.DataArray.plot.line
|
|
114
147
|
"""
|
|
115
148
|
|
|
116
|
-
return self.area.plot.step(
|
|
149
|
+
return self.area.plot.step(
|
|
150
|
+
x="e_min_unfolded", xscale=xscale, yscale=yscale, where=where, **kwargs
|
|
151
|
+
)
|
jaxspec/data/obsconf.py
CHANGED
|
@@ -157,8 +157,10 @@ class ObsConfiguration(xr.Dataset):
|
|
|
157
157
|
|
|
158
158
|
if observation.folded_background is not None:
|
|
159
159
|
folded_background = observation.folded_background.data[row_idx]
|
|
160
|
+
folded_background_unscaled = observation.folded_background_unscaled.data[row_idx]
|
|
160
161
|
else:
|
|
161
162
|
folded_background = np.zeros_like(folded_counts)
|
|
163
|
+
folded_background_unscaled = np.zeros_like(folded_counts)
|
|
162
164
|
|
|
163
165
|
data_dict = {
|
|
164
166
|
"transfer_matrix": (
|
|
@@ -200,6 +202,14 @@ class ObsConfiguration(xr.Dataset):
|
|
|
200
202
|
"unit": "photons",
|
|
201
203
|
},
|
|
202
204
|
),
|
|
205
|
+
"folded_background_unscaled": (
|
|
206
|
+
["folded_channel"],
|
|
207
|
+
folded_background_unscaled,
|
|
208
|
+
{
|
|
209
|
+
"description": "To be done",
|
|
210
|
+
"unit": "photons",
|
|
211
|
+
},
|
|
212
|
+
),
|
|
203
213
|
}
|
|
204
214
|
|
|
205
215
|
return cls(
|
|
@@ -234,8 +244,8 @@ class ObsConfiguration(xr.Dataset):
|
|
|
234
244
|
cls,
|
|
235
245
|
instrument: Instrument,
|
|
236
246
|
exposure: float,
|
|
237
|
-
low_energy: float = 1e-
|
|
238
|
-
high_energy: float =
|
|
247
|
+
low_energy: float = 1e-300,
|
|
248
|
+
high_energy: float = 1e300,
|
|
239
249
|
):
|
|
240
250
|
"""
|
|
241
251
|
Create a mock observation configuration from an instrument object. The fake observation will have zero counts.
|
jaxspec/data/observation.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import numpy as np
|
|
2
2
|
import xarray as xr
|
|
3
|
+
|
|
3
4
|
from .ogip import DataPHA
|
|
4
5
|
|
|
5
6
|
|
|
@@ -23,7 +24,16 @@ class Observation(xr.Dataset):
|
|
|
23
24
|
folded_background: xr.DataArray
|
|
24
25
|
"""The background counts, after grouping"""
|
|
25
26
|
|
|
26
|
-
__slots__ = (
|
|
27
|
+
__slots__ = (
|
|
28
|
+
"grouping",
|
|
29
|
+
"channel",
|
|
30
|
+
"quality",
|
|
31
|
+
"exposure",
|
|
32
|
+
"background",
|
|
33
|
+
"folded_background",
|
|
34
|
+
"counts",
|
|
35
|
+
"folded_counts",
|
|
36
|
+
)
|
|
27
37
|
|
|
28
38
|
_default_attributes = {"description": "X-ray observation dataset"}
|
|
29
39
|
|
|
@@ -36,17 +46,23 @@ class Observation(xr.Dataset):
|
|
|
36
46
|
quality,
|
|
37
47
|
exposure,
|
|
38
48
|
background=None,
|
|
49
|
+
background_unscaled=None,
|
|
39
50
|
backratio=1.0,
|
|
40
51
|
attributes: dict | None = None,
|
|
41
52
|
):
|
|
42
53
|
if attributes is None:
|
|
43
54
|
attributes = {}
|
|
44
55
|
|
|
45
|
-
if background is None:
|
|
56
|
+
if background is None or background_unscaled is None:
|
|
46
57
|
background = np.zeros_like(counts, dtype=np.int64)
|
|
58
|
+
background_unscaled = np.zeros_like(counts, dtype=np.int64)
|
|
47
59
|
|
|
48
60
|
data_dict = {
|
|
49
|
-
"counts": (
|
|
61
|
+
"counts": (
|
|
62
|
+
["instrument_channel"],
|
|
63
|
+
np.asarray(counts, dtype=np.int64),
|
|
64
|
+
{"description": "Counts", "unit": "photons"},
|
|
65
|
+
),
|
|
50
66
|
"folded_counts": (
|
|
51
67
|
["folded_channel"],
|
|
52
68
|
np.asarray(np.ma.filled(grouping @ counts), dtype=np.int64),
|
|
@@ -57,7 +73,11 @@ class Observation(xr.Dataset):
|
|
|
57
73
|
grouping,
|
|
58
74
|
{"description": "Grouping matrix."},
|
|
59
75
|
),
|
|
60
|
-
"quality": (
|
|
76
|
+
"quality": (
|
|
77
|
+
["instrument_channel"],
|
|
78
|
+
np.asarray(quality, dtype=np.int64),
|
|
79
|
+
{"description": "Quality flag."},
|
|
80
|
+
),
|
|
61
81
|
"exposure": ([], float(exposure), {"description": "Total exposure", "unit": "s"}),
|
|
62
82
|
"backratio": (
|
|
63
83
|
["instrument_channel"],
|
|
@@ -74,9 +94,14 @@ class Observation(xr.Dataset):
|
|
|
74
94
|
np.asarray(background, dtype=np.int64),
|
|
75
95
|
{"description": "Background counts", "unit": "photons"},
|
|
76
96
|
),
|
|
97
|
+
"folded_background_unscaled": (
|
|
98
|
+
["folded_channel"],
|
|
99
|
+
np.asarray(np.ma.filled(grouping @ background_unscaled), dtype=np.int64),
|
|
100
|
+
{"description": "Background counts", "unit": "photons"},
|
|
101
|
+
),
|
|
77
102
|
"folded_background": (
|
|
78
103
|
["folded_channel"],
|
|
79
|
-
np.asarray(np.ma.filled(grouping @ background), dtype=np.
|
|
104
|
+
np.asarray(np.ma.filled(grouping @ background), dtype=np.float64),
|
|
80
105
|
{"description": "Background counts", "unit": "photons"},
|
|
81
106
|
),
|
|
82
107
|
}
|
|
@@ -84,36 +109,59 @@ class Observation(xr.Dataset):
|
|
|
84
109
|
return cls(
|
|
85
110
|
data_dict,
|
|
86
111
|
coords={
|
|
87
|
-
"channel": (
|
|
112
|
+
"channel": (
|
|
113
|
+
["instrument_channel"],
|
|
114
|
+
np.asarray(channel, dtype=np.int64),
|
|
115
|
+
{"description": "Channel number"},
|
|
116
|
+
),
|
|
88
117
|
"grouped_channel": (
|
|
89
118
|
["folded_channel"],
|
|
90
119
|
np.arange(len(grouping @ counts), dtype=np.int64),
|
|
91
120
|
{"description": "Channel number"},
|
|
92
121
|
),
|
|
93
122
|
},
|
|
94
|
-
attrs=cls._default_attributes
|
|
123
|
+
attrs=cls._default_attributes
|
|
124
|
+
if attributes is None
|
|
125
|
+
else attributes | cls._default_attributes,
|
|
95
126
|
)
|
|
96
127
|
|
|
97
128
|
@classmethod
|
|
98
129
|
def from_ogip_container(cls, pha: DataPHA, bkg: DataPHA | None = None, **metadata):
|
|
99
130
|
if bkg is not None:
|
|
100
|
-
backratio = np.nan_to_num(
|
|
131
|
+
backratio = np.nan_to_num(
|
|
132
|
+
(pha.backscal * pha.exposure * pha.areascal)
|
|
133
|
+
/ (bkg.backscal * bkg.exposure * bkg.areascal)
|
|
134
|
+
)
|
|
101
135
|
else:
|
|
102
136
|
backratio = np.ones_like(pha.counts)
|
|
103
137
|
|
|
138
|
+
if (bkg is not None) and ("NET" in pha.flags):
|
|
139
|
+
counts = pha.counts + bkg.counts * backratio
|
|
140
|
+
else:
|
|
141
|
+
counts = pha.counts
|
|
142
|
+
|
|
104
143
|
return cls.from_matrix(
|
|
105
|
-
|
|
144
|
+
counts,
|
|
106
145
|
pha.grouping,
|
|
107
146
|
pha.channel,
|
|
108
147
|
pha.quality,
|
|
109
148
|
pha.exposure,
|
|
110
149
|
backratio=backratio,
|
|
111
|
-
background=bkg.counts if bkg is not None else None,
|
|
150
|
+
background=bkg.counts * backratio if bkg is not None else None,
|
|
151
|
+
background_unscaled=bkg.counts if bkg is not None else None,
|
|
112
152
|
attributes=metadata,
|
|
113
153
|
)
|
|
114
154
|
|
|
115
155
|
@classmethod
|
|
116
156
|
def from_pha_file(cls, pha_path: str, bkg_path: str | None = None, **metadata):
|
|
157
|
+
"""
|
|
158
|
+
Build an observation from a PHA file
|
|
159
|
+
|
|
160
|
+
Parameters:
|
|
161
|
+
pha_path : Path to the PHA file
|
|
162
|
+
bkg_path : Path to the background file
|
|
163
|
+
metadata : Additional metadata to add to the observation
|
|
164
|
+
"""
|
|
117
165
|
from .util import data_path_finder
|
|
118
166
|
|
|
119
167
|
arf_path, rmf_path, bkg_path_default = data_path_finder(pha_path)
|
|
@@ -155,7 +203,16 @@ class Observation(xr.Dataset):
|
|
|
155
203
|
|
|
156
204
|
fig = plt.figure(figsize=(6, 6))
|
|
157
205
|
gs = fig.add_gridspec(
|
|
158
|
-
2,
|
|
206
|
+
2,
|
|
207
|
+
2,
|
|
208
|
+
width_ratios=(4, 1),
|
|
209
|
+
height_ratios=(1, 4),
|
|
210
|
+
left=0.1,
|
|
211
|
+
right=0.9,
|
|
212
|
+
bottom=0.1,
|
|
213
|
+
top=0.9,
|
|
214
|
+
wspace=0.05,
|
|
215
|
+
hspace=0.05,
|
|
159
216
|
)
|
|
160
217
|
ax = fig.add_subplot(gs[1, 0])
|
|
161
218
|
ax_histx = fig.add_subplot(gs[0, 0], sharex=ax)
|
jaxspec/data/ogip.py
CHANGED
|
@@ -1,9 +1,11 @@
|
|
|
1
|
-
import numpy as np
|
|
2
1
|
import os
|
|
2
|
+
|
|
3
3
|
import astropy.units as u
|
|
4
|
+
import numpy as np
|
|
4
5
|
import sparse
|
|
5
|
-
|
|
6
|
+
|
|
6
7
|
from astropy.io import fits
|
|
8
|
+
from astropy.table import QTable
|
|
7
9
|
|
|
8
10
|
|
|
9
11
|
class DataPHA:
|
|
@@ -25,6 +27,7 @@ class DataPHA:
|
|
|
25
27
|
ancrfile=None,
|
|
26
28
|
backscal=1.0,
|
|
27
29
|
areascal=1.0,
|
|
30
|
+
flags=None,
|
|
28
31
|
):
|
|
29
32
|
self.channel = np.asarray(channel, dtype=int)
|
|
30
33
|
self.counts = np.asarray(counts, dtype=int)
|
|
@@ -36,6 +39,7 @@ class DataPHA:
|
|
|
36
39
|
self.ancrfile = ancrfile
|
|
37
40
|
self.backscal = np.asarray(backscal, dtype=float)
|
|
38
41
|
self.areascal = np.asarray(areascal, dtype=float)
|
|
42
|
+
self.flags = flags
|
|
39
43
|
|
|
40
44
|
if grouping is not None:
|
|
41
45
|
# Indices array of the beginning of each group
|
|
@@ -55,7 +59,9 @@ class DataPHA:
|
|
|
55
59
|
data.append(True)
|
|
56
60
|
|
|
57
61
|
# Create a COO sparse matrix
|
|
58
|
-
grp_matrix = sparse.COO(
|
|
62
|
+
grp_matrix = sparse.COO(
|
|
63
|
+
(data, (rows, cols)), shape=(len(b_grp), len(channel)), fill_value=0
|
|
64
|
+
)
|
|
59
65
|
|
|
60
66
|
else:
|
|
61
67
|
# Identity matrix case, use sparse for efficiency
|
|
@@ -74,12 +80,7 @@ class DataPHA:
|
|
|
74
80
|
|
|
75
81
|
data = QTable.read(pha_file, "SPECTRUM")
|
|
76
82
|
header = fits.getheader(pha_file, "SPECTRUM")
|
|
77
|
-
|
|
78
|
-
if header.get("HDUCLAS2") == "NET":
|
|
79
|
-
raise ValueError(
|
|
80
|
-
f"The HDUCLAS2={header.get('HDUCLAS2')} keyword in the PHA file is not supported."
|
|
81
|
-
f"Please open an issue if this is required."
|
|
82
|
-
)
|
|
83
|
+
flags = []
|
|
83
84
|
|
|
84
85
|
if header.get("HDUCLAS3") == "RATE":
|
|
85
86
|
raise ValueError(
|
|
@@ -114,6 +115,8 @@ class DataPHA:
|
|
|
114
115
|
else:
|
|
115
116
|
raise ValueError("No BACKSCAL found in the PHA file.")
|
|
116
117
|
|
|
118
|
+
backscal = np.where(backscal == 0, 1.0, backscal)
|
|
119
|
+
|
|
117
120
|
if "AREASCAL" in header:
|
|
118
121
|
areascal = header["AREASCAL"]
|
|
119
122
|
elif "AREASCAL" in data.colnames:
|
|
@@ -121,8 +124,9 @@ class DataPHA:
|
|
|
121
124
|
else:
|
|
122
125
|
raise ValueError("No AREASCAL found in the PHA file.")
|
|
123
126
|
|
|
124
|
-
|
|
125
|
-
|
|
127
|
+
if header.get("HDUCLAS2") == "NET":
|
|
128
|
+
flags.append("NET")
|
|
129
|
+
|
|
126
130
|
kwargs = {
|
|
127
131
|
"grouping": grouping,
|
|
128
132
|
"quality": quality,
|
|
@@ -131,6 +135,7 @@ class DataPHA:
|
|
|
131
135
|
"ancrfile": header.get("ANCRFILE"),
|
|
132
136
|
"backscal": backscal,
|
|
133
137
|
"areascal": areascal,
|
|
138
|
+
"flags": flags,
|
|
134
139
|
}
|
|
135
140
|
|
|
136
141
|
return cls(data["CHANNEL"], data["COUNTS"], header["EXPOSURE"], **kwargs)
|
|
@@ -176,7 +181,19 @@ class DataRMF:
|
|
|
176
181
|
* [The Calibration Requirements for Spectral Analysis Addendum: Changes log](https://heasarc.gsfc.nasa.gov/docs/heasarc/caldb/docs/memos/cal_gen_92_002a/cal_gen_92_002a.html)
|
|
177
182
|
"""
|
|
178
183
|
|
|
179
|
-
def __init__(
|
|
184
|
+
def __init__(
|
|
185
|
+
self,
|
|
186
|
+
energ_lo,
|
|
187
|
+
energ_hi,
|
|
188
|
+
n_grp,
|
|
189
|
+
f_chan,
|
|
190
|
+
n_chan,
|
|
191
|
+
matrix,
|
|
192
|
+
channel,
|
|
193
|
+
e_min,
|
|
194
|
+
e_max,
|
|
195
|
+
low_threshold=0.0,
|
|
196
|
+
):
|
|
180
197
|
# RMF stuff
|
|
181
198
|
self.energ_lo = energ_lo # "Entry" energies
|
|
182
199
|
self.energ_hi = energ_hi # "Entry" energies
|
|
@@ -229,7 +246,9 @@ class DataRMF:
|
|
|
229
246
|
idxs = data > low_threshold
|
|
230
247
|
|
|
231
248
|
# Create a COO sparse matrix and then convert to CSR for efficiency
|
|
232
|
-
coo = sparse.COO(
|
|
249
|
+
coo = sparse.COO(
|
|
250
|
+
[rows[idxs], cols[idxs]], data[idxs], shape=(len(self.energ_lo), len(self.channel))
|
|
251
|
+
)
|
|
233
252
|
self.sparse_matrix = coo.T # .tocsr()
|
|
234
253
|
|
|
235
254
|
@property
|
jaxspec/data/util.py
CHANGED
|
@@ -2,16 +2,13 @@ from collections.abc import Mapping
|
|
|
2
2
|
from pathlib import Path
|
|
3
3
|
from typing import Literal, TypeVar
|
|
4
4
|
|
|
5
|
-
import haiku as hk
|
|
6
5
|
import jax
|
|
7
|
-
import numpy as np
|
|
8
6
|
import numpyro
|
|
9
7
|
|
|
10
8
|
from astropy.io import fits
|
|
11
|
-
from numpy.typing import ArrayLike
|
|
12
9
|
from numpyro import handlers
|
|
13
10
|
|
|
14
|
-
from ..
|
|
11
|
+
from .._fit._build_model import forward_model
|
|
15
12
|
from ..model.abc import SpectralModel
|
|
16
13
|
from ..util.online_storage import table_manager
|
|
17
14
|
from . import Instrument, ObsConfiguration, Observation
|
|
@@ -127,67 +124,6 @@ def load_example_obsconf(source: Literal["NGC7793_ULX4_PN", "NGC7793_ULX4_ALL"])
|
|
|
127
124
|
raise ValueError(f"{source} not recognized.")
|
|
128
125
|
|
|
129
126
|
|
|
130
|
-
def fakeit(
|
|
131
|
-
instrument: ObsConfiguration | list[ObsConfiguration],
|
|
132
|
-
model: SpectralModel,
|
|
133
|
-
parameters: Mapping[K, V],
|
|
134
|
-
rng_key: int = 0,
|
|
135
|
-
sparsify_matrix: bool = False,
|
|
136
|
-
) -> ArrayLike | list[ArrayLike]:
|
|
137
|
-
"""
|
|
138
|
-
Convenience function to simulate a spectrum from a given model and a set of parameters.
|
|
139
|
-
It requires an instrumental setup, and unlike in
|
|
140
|
-
[XSPEC's fakeit](https://heasarc.gsfc.nasa.gov/xanadu/xspec/manual/node72.html), the error on the counts is given
|
|
141
|
-
exclusively by Poisson statistics.
|
|
142
|
-
|
|
143
|
-
Parameters:
|
|
144
|
-
instrument: The instrumental setup.
|
|
145
|
-
model: The model to use.
|
|
146
|
-
parameters: The parameters of the model.
|
|
147
|
-
rng_key: The random number generator seed.
|
|
148
|
-
sparsify_matrix: Whether to sparsify the matrix or not.
|
|
149
|
-
"""
|
|
150
|
-
|
|
151
|
-
instruments = [instrument] if isinstance(instrument, ObsConfiguration) else instrument
|
|
152
|
-
fakeits = []
|
|
153
|
-
|
|
154
|
-
for i, instrument in enumerate(instruments):
|
|
155
|
-
transformed_model = hk.without_apply_rng(
|
|
156
|
-
hk.transform(
|
|
157
|
-
lambda par: CountForwardModel(model, instrument, sparse=sparsify_matrix)(par)
|
|
158
|
-
)
|
|
159
|
-
)
|
|
160
|
-
|
|
161
|
-
def obs_model(p):
|
|
162
|
-
return transformed_model.apply(None, p)
|
|
163
|
-
|
|
164
|
-
with handlers.seed(rng_seed=rng_key):
|
|
165
|
-
counts = numpyro.sample(
|
|
166
|
-
f"likelihood_obs_{i}",
|
|
167
|
-
numpyro.distributions.Poisson(obs_model(parameters)),
|
|
168
|
-
)
|
|
169
|
-
|
|
170
|
-
"""
|
|
171
|
-
pha = DataPHA(
|
|
172
|
-
instrument.rmf.channel,
|
|
173
|
-
np.array(counts, dtype=int)*u.ct,
|
|
174
|
-
instrument.exposure,
|
|
175
|
-
grouping=instrument.grouping)
|
|
176
|
-
|
|
177
|
-
observation = Observation(
|
|
178
|
-
pha=pha,
|
|
179
|
-
arf=instrument.arf,
|
|
180
|
-
rmf=instrument.rmf,
|
|
181
|
-
low_energy=instrument.low_energy,
|
|
182
|
-
high_energy=instrument.high_energy
|
|
183
|
-
)
|
|
184
|
-
"""
|
|
185
|
-
|
|
186
|
-
fakeits.append(np.array(counts, dtype=int))
|
|
187
|
-
|
|
188
|
-
return fakeits[0] if len(fakeits) == 1 else fakeits
|
|
189
|
-
|
|
190
|
-
|
|
191
127
|
def fakeit_for_multiple_parameters(
|
|
192
128
|
instrument: ObsConfiguration | list[ObsConfiguration],
|
|
193
129
|
model: SpectralModel,
|
|
@@ -199,7 +135,6 @@ def fakeit_for_multiple_parameters(
|
|
|
199
135
|
"""
|
|
200
136
|
Convenience function to simulate multiple spectra from a given model and a set of parameters.
|
|
201
137
|
|
|
202
|
-
TODO : avoid redundancy, better doc and type hints
|
|
203
138
|
|
|
204
139
|
Parameters:
|
|
205
140
|
instrument: The instrumental setup.
|
|
@@ -214,24 +149,19 @@ def fakeit_for_multiple_parameters(
|
|
|
214
149
|
fakeits = []
|
|
215
150
|
|
|
216
151
|
for i, obs in enumerate(instruments):
|
|
217
|
-
|
|
218
|
-
|
|
152
|
+
countrate = jax.vmap(lambda p: forward_model(model, p, instrument, sparse=sparsify_matrix))(
|
|
153
|
+
parameters
|
|
219
154
|
)
|
|
220
155
|
|
|
221
|
-
@jax.jit
|
|
222
|
-
@jax.vmap
|
|
223
|
-
def obs_model(p):
|
|
224
|
-
return transformed_model.apply(None, p)
|
|
225
|
-
|
|
226
156
|
if apply_stat:
|
|
227
157
|
with handlers.seed(rng_seed=rng_key):
|
|
228
158
|
spectrum = numpyro.sample(
|
|
229
159
|
f"likelihood_obs_{i}",
|
|
230
|
-
numpyro.distributions.Poisson(
|
|
160
|
+
numpyro.distributions.Poisson(countrate),
|
|
231
161
|
)
|
|
232
162
|
|
|
233
163
|
else:
|
|
234
|
-
spectrum =
|
|
164
|
+
spectrum = countrate
|
|
235
165
|
|
|
236
166
|
fakeits.append(spectrum)
|
|
237
167
|
|