jaxspec 0.1.4__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxspec/_fit/_build_model.py +26 -103
- jaxspec/analysis/_plot.py +166 -7
- jaxspec/analysis/results.py +231 -332
- jaxspec/data/instrument.py +47 -12
- jaxspec/data/obsconf.py +12 -2
- jaxspec/data/observation.py +17 -4
- jaxspec/data/ogip.py +32 -13
- jaxspec/data/util.py +60 -70
- jaxspec/fit.py +76 -44
- jaxspec/model/_graph_util.py +151 -0
- jaxspec/model/abc.py +275 -414
- jaxspec/model/additive.py +276 -289
- jaxspec/model/background.py +3 -4
- jaxspec/model/multiplicative.py +102 -86
- jaxspec/scripts/debug.py +1 -1
- jaxspec/util/__init__.py +0 -45
- jaxspec/util/misc.py +25 -0
- jaxspec/util/typing.py +0 -63
- {jaxspec-0.1.4.dist-info → jaxspec-0.2.1.dist-info}/METADATA +13 -14
- jaxspec-0.2.1.dist-info/RECORD +34 -0
- {jaxspec-0.1.4.dist-info → jaxspec-0.2.1.dist-info}/WHEEL +1 -1
- jaxspec/data/grouping.py +0 -23
- jaxspec-0.1.4.dist-info/RECORD +0 -33
- {jaxspec-0.1.4.dist-info → jaxspec-0.2.1.dist-info}/LICENSE.md +0 -0
- {jaxspec-0.1.4.dist-info → jaxspec-0.2.1.dist-info}/entry_points.txt +0 -0
jaxspec/data/instrument.py
CHANGED
|
@@ -1,7 +1,10 @@
|
|
|
1
1
|
import os
|
|
2
|
+
|
|
2
3
|
import numpy as np
|
|
3
4
|
import xarray as xr
|
|
5
|
+
|
|
4
6
|
from matplotlib import colors
|
|
7
|
+
|
|
5
8
|
from .ogip import DataARF, DataRMF
|
|
6
9
|
|
|
7
10
|
|
|
@@ -25,7 +28,15 @@ class Instrument(xr.Dataset):
|
|
|
25
28
|
)
|
|
26
29
|
|
|
27
30
|
@classmethod
|
|
28
|
-
def from_matrix(
|
|
31
|
+
def from_matrix(
|
|
32
|
+
cls,
|
|
33
|
+
redistribution_matrix,
|
|
34
|
+
spectral_response,
|
|
35
|
+
e_min_unfolded,
|
|
36
|
+
e_max_unfolded,
|
|
37
|
+
e_min_channel,
|
|
38
|
+
e_max_channel,
|
|
39
|
+
):
|
|
29
40
|
return cls(
|
|
30
41
|
{
|
|
31
42
|
"redistribution": (
|
|
@@ -65,7 +76,7 @@ class Instrument(xr.Dataset):
|
|
|
65
76
|
)
|
|
66
77
|
|
|
67
78
|
@classmethod
|
|
68
|
-
def from_ogip_file(cls, rmf_path: str | os.PathLike, arf_path: str | os.PathLike = None):
|
|
79
|
+
def from_ogip_file(cls, rmf_path: str | os.PathLike, arf_path: str | os.PathLike | None = None):
|
|
69
80
|
"""
|
|
70
81
|
Load the data from OGIP files.
|
|
71
82
|
|
|
@@ -80,37 +91,61 @@ class Instrument(xr.Dataset):
|
|
|
80
91
|
specresp = DataARF.from_file(arf_path).specresp
|
|
81
92
|
|
|
82
93
|
else:
|
|
83
|
-
specresp =
|
|
94
|
+
specresp = rmf.matrix.sum(axis=0)
|
|
95
|
+
rmf.sparse_matrix /= specresp
|
|
84
96
|
|
|
85
|
-
return cls.from_matrix(
|
|
97
|
+
return cls.from_matrix(
|
|
98
|
+
rmf.sparse_matrix, specresp, rmf.energ_lo, rmf.energ_hi, rmf.e_min, rmf.e_max
|
|
99
|
+
)
|
|
86
100
|
|
|
87
|
-
def plot_redistribution(
|
|
101
|
+
def plot_redistribution(
|
|
102
|
+
self,
|
|
103
|
+
xscale: str = "log",
|
|
104
|
+
yscale: str = "log",
|
|
105
|
+
cmap=None,
|
|
106
|
+
vmin: float = 1e-6,
|
|
107
|
+
vmax: float = 1e0,
|
|
108
|
+
add_labels: bool = True,
|
|
109
|
+
**kwargs,
|
|
110
|
+
):
|
|
88
111
|
"""
|
|
89
112
|
Plot the redistribution probability matrix
|
|
90
113
|
|
|
91
114
|
Parameters:
|
|
115
|
+
xscale : The scale of the x-axis.
|
|
116
|
+
yscale : The scale of the y-axis.
|
|
117
|
+
cmap : The colormap to use.
|
|
118
|
+
vmin : The minimum value for the colormap.
|
|
119
|
+
vmax : The maximum value for the colormap.
|
|
120
|
+
add_labels : Whether to add labels to the plot.
|
|
92
121
|
**kwargs : `kwargs` passed to https://docs.xarray.dev/en/latest/generated/xarray.plot.pcolormesh.html#xarray.plot.pcolormesh
|
|
93
122
|
"""
|
|
123
|
+
|
|
94
124
|
import cmasher as cmr
|
|
95
125
|
|
|
96
126
|
return xr.plot.pcolormesh(
|
|
97
127
|
self.redistribution,
|
|
98
128
|
x="e_max_unfolded",
|
|
99
129
|
y="e_max_channel",
|
|
100
|
-
xscale=
|
|
101
|
-
yscale=
|
|
102
|
-
cmap=cmr.ember_r,
|
|
103
|
-
norm=colors.LogNorm(vmin=
|
|
104
|
-
add_labels=
|
|
130
|
+
xscale=xscale,
|
|
131
|
+
yscale=yscale,
|
|
132
|
+
cmap=cmr.ember_r if cmap is None else cmap,
|
|
133
|
+
norm=colors.LogNorm(vmin=vmin, vmax=vmax),
|
|
134
|
+
add_labels=add_labels,
|
|
105
135
|
**kwargs,
|
|
106
136
|
)
|
|
107
137
|
|
|
108
|
-
def plot_area(self, **kwargs):
|
|
138
|
+
def plot_area(self, xscale: str = "log", yscale: str = "log", where: str = "post", **kwargs):
|
|
109
139
|
"""
|
|
110
140
|
Plot the effective area
|
|
111
141
|
|
|
112
142
|
Parameters:
|
|
143
|
+
xscale : The scale of the x-axis.
|
|
144
|
+
yscale : The scale of the y-axis.
|
|
145
|
+
where : The position of the steps.
|
|
113
146
|
**kwargs : `kwargs` passed to https://docs.xarray.dev/en/latest/generated/xarray.DataArray.plot.line.html#xarray.DataArray.plot.line
|
|
114
147
|
"""
|
|
115
148
|
|
|
116
|
-
return self.area.plot.step(
|
|
149
|
+
return self.area.plot.step(
|
|
150
|
+
x="e_min_unfolded", xscale=xscale, yscale=yscale, where=where, **kwargs
|
|
151
|
+
)
|
jaxspec/data/obsconf.py
CHANGED
|
@@ -157,8 +157,10 @@ class ObsConfiguration(xr.Dataset):
|
|
|
157
157
|
|
|
158
158
|
if observation.folded_background is not None:
|
|
159
159
|
folded_background = observation.folded_background.data[row_idx]
|
|
160
|
+
folded_background_unscaled = observation.folded_background_unscaled.data[row_idx]
|
|
160
161
|
else:
|
|
161
162
|
folded_background = np.zeros_like(folded_counts)
|
|
163
|
+
folded_background_unscaled = np.zeros_like(folded_counts)
|
|
162
164
|
|
|
163
165
|
data_dict = {
|
|
164
166
|
"transfer_matrix": (
|
|
@@ -200,6 +202,14 @@ class ObsConfiguration(xr.Dataset):
|
|
|
200
202
|
"unit": "photons",
|
|
201
203
|
},
|
|
202
204
|
),
|
|
205
|
+
"folded_background_unscaled": (
|
|
206
|
+
["folded_channel"],
|
|
207
|
+
folded_background_unscaled,
|
|
208
|
+
{
|
|
209
|
+
"description": "To be done",
|
|
210
|
+
"unit": "photons",
|
|
211
|
+
},
|
|
212
|
+
),
|
|
203
213
|
}
|
|
204
214
|
|
|
205
215
|
return cls(
|
|
@@ -234,8 +244,8 @@ class ObsConfiguration(xr.Dataset):
|
|
|
234
244
|
cls,
|
|
235
245
|
instrument: Instrument,
|
|
236
246
|
exposure: float,
|
|
237
|
-
low_energy: float = 1e-
|
|
238
|
-
high_energy: float =
|
|
247
|
+
low_energy: float = 1e-300,
|
|
248
|
+
high_energy: float = 1e300,
|
|
239
249
|
):
|
|
240
250
|
"""
|
|
241
251
|
Create a mock observation configuration from an instrument object. The fake observation will have zero counts.
|
jaxspec/data/observation.py
CHANGED
|
@@ -46,14 +46,16 @@ class Observation(xr.Dataset):
|
|
|
46
46
|
quality,
|
|
47
47
|
exposure,
|
|
48
48
|
background=None,
|
|
49
|
+
background_unscaled=None,
|
|
49
50
|
backratio=1.0,
|
|
50
51
|
attributes: dict | None = None,
|
|
51
52
|
):
|
|
52
53
|
if attributes is None:
|
|
53
54
|
attributes = {}
|
|
54
55
|
|
|
55
|
-
if background is None:
|
|
56
|
+
if background is None or background_unscaled is None:
|
|
56
57
|
background = np.zeros_like(counts, dtype=np.int64)
|
|
58
|
+
background_unscaled = np.zeros_like(counts, dtype=np.int64)
|
|
57
59
|
|
|
58
60
|
data_dict = {
|
|
59
61
|
"counts": (
|
|
@@ -92,9 +94,14 @@ class Observation(xr.Dataset):
|
|
|
92
94
|
np.asarray(background, dtype=np.int64),
|
|
93
95
|
{"description": "Background counts", "unit": "photons"},
|
|
94
96
|
),
|
|
97
|
+
"folded_background_unscaled": (
|
|
98
|
+
["folded_channel"],
|
|
99
|
+
np.asarray(np.ma.filled(grouping @ background_unscaled), dtype=np.int64),
|
|
100
|
+
{"description": "Background counts", "unit": "photons"},
|
|
101
|
+
),
|
|
95
102
|
"folded_background": (
|
|
96
103
|
["folded_channel"],
|
|
97
|
-
np.asarray(np.ma.filled(grouping @ background), dtype=np.
|
|
104
|
+
np.asarray(np.ma.filled(grouping @ background), dtype=np.float64),
|
|
98
105
|
{"description": "Background counts", "unit": "photons"},
|
|
99
106
|
),
|
|
100
107
|
}
|
|
@@ -128,14 +135,20 @@ class Observation(xr.Dataset):
|
|
|
128
135
|
else:
|
|
129
136
|
backratio = np.ones_like(pha.counts)
|
|
130
137
|
|
|
138
|
+
if (bkg is not None) and ("NET" in pha.flags):
|
|
139
|
+
counts = pha.counts + bkg.counts * backratio
|
|
140
|
+
else:
|
|
141
|
+
counts = pha.counts
|
|
142
|
+
|
|
131
143
|
return cls.from_matrix(
|
|
132
|
-
|
|
144
|
+
counts,
|
|
133
145
|
pha.grouping,
|
|
134
146
|
pha.channel,
|
|
135
147
|
pha.quality,
|
|
136
148
|
pha.exposure,
|
|
137
149
|
backratio=backratio,
|
|
138
|
-
background=bkg.counts if bkg is not None else None,
|
|
150
|
+
background=bkg.counts * backratio if bkg is not None else None,
|
|
151
|
+
background_unscaled=bkg.counts if bkg is not None else None,
|
|
139
152
|
attributes=metadata,
|
|
140
153
|
)
|
|
141
154
|
|
jaxspec/data/ogip.py
CHANGED
|
@@ -1,9 +1,11 @@
|
|
|
1
|
-
import numpy as np
|
|
2
1
|
import os
|
|
2
|
+
|
|
3
3
|
import astropy.units as u
|
|
4
|
+
import numpy as np
|
|
4
5
|
import sparse
|
|
5
|
-
|
|
6
|
+
|
|
6
7
|
from astropy.io import fits
|
|
8
|
+
from astropy.table import QTable
|
|
7
9
|
|
|
8
10
|
|
|
9
11
|
class DataPHA:
|
|
@@ -25,6 +27,7 @@ class DataPHA:
|
|
|
25
27
|
ancrfile=None,
|
|
26
28
|
backscal=1.0,
|
|
27
29
|
areascal=1.0,
|
|
30
|
+
flags=None,
|
|
28
31
|
):
|
|
29
32
|
self.channel = np.asarray(channel, dtype=int)
|
|
30
33
|
self.counts = np.asarray(counts, dtype=int)
|
|
@@ -36,6 +39,7 @@ class DataPHA:
|
|
|
36
39
|
self.ancrfile = ancrfile
|
|
37
40
|
self.backscal = np.asarray(backscal, dtype=float)
|
|
38
41
|
self.areascal = np.asarray(areascal, dtype=float)
|
|
42
|
+
self.flags = flags
|
|
39
43
|
|
|
40
44
|
if grouping is not None:
|
|
41
45
|
# Indices array of the beginning of each group
|
|
@@ -55,7 +59,9 @@ class DataPHA:
|
|
|
55
59
|
data.append(True)
|
|
56
60
|
|
|
57
61
|
# Create a COO sparse matrix
|
|
58
|
-
grp_matrix = sparse.COO(
|
|
62
|
+
grp_matrix = sparse.COO(
|
|
63
|
+
(data, (rows, cols)), shape=(len(b_grp), len(channel)), fill_value=0
|
|
64
|
+
)
|
|
59
65
|
|
|
60
66
|
else:
|
|
61
67
|
# Identity matrix case, use sparse for efficiency
|
|
@@ -74,12 +80,7 @@ class DataPHA:
|
|
|
74
80
|
|
|
75
81
|
data = QTable.read(pha_file, "SPECTRUM")
|
|
76
82
|
header = fits.getheader(pha_file, "SPECTRUM")
|
|
77
|
-
|
|
78
|
-
if header.get("HDUCLAS2") == "NET":
|
|
79
|
-
raise ValueError(
|
|
80
|
-
f"The HDUCLAS2={header.get('HDUCLAS2')} keyword in the PHA file is not supported."
|
|
81
|
-
f"Please open an issue if this is required."
|
|
82
|
-
)
|
|
83
|
+
flags = []
|
|
83
84
|
|
|
84
85
|
if header.get("HDUCLAS3") == "RATE":
|
|
85
86
|
raise ValueError(
|
|
@@ -114,6 +115,8 @@ class DataPHA:
|
|
|
114
115
|
else:
|
|
115
116
|
raise ValueError("No BACKSCAL found in the PHA file.")
|
|
116
117
|
|
|
118
|
+
backscal = np.where(backscal == 0, 1.0, backscal)
|
|
119
|
+
|
|
117
120
|
if "AREASCAL" in header:
|
|
118
121
|
areascal = header["AREASCAL"]
|
|
119
122
|
elif "AREASCAL" in data.colnames:
|
|
@@ -121,8 +124,9 @@ class DataPHA:
|
|
|
121
124
|
else:
|
|
122
125
|
raise ValueError("No AREASCAL found in the PHA file.")
|
|
123
126
|
|
|
124
|
-
|
|
125
|
-
|
|
127
|
+
if header.get("HDUCLAS2") == "NET":
|
|
128
|
+
flags.append("NET")
|
|
129
|
+
|
|
126
130
|
kwargs = {
|
|
127
131
|
"grouping": grouping,
|
|
128
132
|
"quality": quality,
|
|
@@ -131,6 +135,7 @@ class DataPHA:
|
|
|
131
135
|
"ancrfile": header.get("ANCRFILE"),
|
|
132
136
|
"backscal": backscal,
|
|
133
137
|
"areascal": areascal,
|
|
138
|
+
"flags": flags,
|
|
134
139
|
}
|
|
135
140
|
|
|
136
141
|
return cls(data["CHANNEL"], data["COUNTS"], header["EXPOSURE"], **kwargs)
|
|
@@ -176,7 +181,19 @@ class DataRMF:
|
|
|
176
181
|
* [The Calibration Requirements for Spectral Analysis Addendum: Changes log](https://heasarc.gsfc.nasa.gov/docs/heasarc/caldb/docs/memos/cal_gen_92_002a/cal_gen_92_002a.html)
|
|
177
182
|
"""
|
|
178
183
|
|
|
179
|
-
def __init__(
|
|
184
|
+
def __init__(
|
|
185
|
+
self,
|
|
186
|
+
energ_lo,
|
|
187
|
+
energ_hi,
|
|
188
|
+
n_grp,
|
|
189
|
+
f_chan,
|
|
190
|
+
n_chan,
|
|
191
|
+
matrix,
|
|
192
|
+
channel,
|
|
193
|
+
e_min,
|
|
194
|
+
e_max,
|
|
195
|
+
low_threshold=0.0,
|
|
196
|
+
):
|
|
180
197
|
# RMF stuff
|
|
181
198
|
self.energ_lo = energ_lo # "Entry" energies
|
|
182
199
|
self.energ_hi = energ_hi # "Entry" energies
|
|
@@ -229,7 +246,9 @@ class DataRMF:
|
|
|
229
246
|
idxs = data > low_threshold
|
|
230
247
|
|
|
231
248
|
# Create a COO sparse matrix and then convert to CSR for efficiency
|
|
232
|
-
coo = sparse.COO(
|
|
249
|
+
coo = sparse.COO(
|
|
250
|
+
[rows[idxs], cols[idxs]], data[idxs], shape=(len(self.energ_lo), len(self.channel))
|
|
251
|
+
)
|
|
233
252
|
self.sparse_matrix = coo.T # .tocsr()
|
|
234
253
|
|
|
235
254
|
@property
|
jaxspec/data/util.py
CHANGED
|
@@ -1,17 +1,16 @@
|
|
|
1
1
|
from collections.abc import Mapping
|
|
2
2
|
from pathlib import Path
|
|
3
|
-
from typing import Literal, TypeVar
|
|
3
|
+
from typing import TYPE_CHECKING, Literal, TypeVar
|
|
4
4
|
|
|
5
|
-
import haiku as hk
|
|
6
5
|
import jax
|
|
6
|
+
import jax.numpy as jnp
|
|
7
7
|
import numpy as np
|
|
8
8
|
import numpyro
|
|
9
9
|
|
|
10
10
|
from astropy.io import fits
|
|
11
|
-
from
|
|
11
|
+
from jax.experimental.sparse import BCOO
|
|
12
12
|
from numpyro import handlers
|
|
13
13
|
|
|
14
|
-
from .._fit._build_model import CountForwardModel
|
|
15
14
|
from ..model.abc import SpectralModel
|
|
16
15
|
from ..util.online_storage import table_manager
|
|
17
16
|
from . import Instrument, ObsConfiguration, Observation
|
|
@@ -19,6 +18,10 @@ from . import Instrument, ObsConfiguration, Observation
|
|
|
19
18
|
K = TypeVar("K")
|
|
20
19
|
V = TypeVar("V")
|
|
21
20
|
|
|
21
|
+
if TYPE_CHECKING:
|
|
22
|
+
from ..data import ObsConfiguration
|
|
23
|
+
from ..model.abc import SpectralModel
|
|
24
|
+
|
|
22
25
|
|
|
23
26
|
def load_example_pha(
|
|
24
27
|
source: Literal["NGC7793_ULX4_PN", "NGC7793_ULX4_ALL"],
|
|
@@ -127,69 +130,40 @@ def load_example_obsconf(source: Literal["NGC7793_ULX4_PN", "NGC7793_ULX4_ALL"])
|
|
|
127
130
|
raise ValueError(f"{source} not recognized.")
|
|
128
131
|
|
|
129
132
|
|
|
130
|
-
def
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
Convenience function to simulate a spectrum from a given model and a set of parameters.
|
|
139
|
-
It requires an instrumental setup, and unlike in
|
|
140
|
-
[XSPEC's fakeit](https://heasarc.gsfc.nasa.gov/xanadu/xspec/manual/node72.html), the error on the counts is given
|
|
141
|
-
exclusively by Poisson statistics.
|
|
142
|
-
|
|
143
|
-
Parameters:
|
|
144
|
-
instrument: The instrumental setup.
|
|
145
|
-
model: The model to use.
|
|
146
|
-
parameters: The parameters of the model.
|
|
147
|
-
rng_key: The random number generator seed.
|
|
148
|
-
sparsify_matrix: Whether to sparsify the matrix or not.
|
|
149
|
-
"""
|
|
150
|
-
|
|
151
|
-
instruments = [instrument] if isinstance(instrument, ObsConfiguration) else instrument
|
|
152
|
-
fakeits = []
|
|
133
|
+
def forward_model_with_multiple_inputs(
|
|
134
|
+
model: "SpectralModel",
|
|
135
|
+
parameters,
|
|
136
|
+
obs_configuration: "ObsConfiguration",
|
|
137
|
+
sparse=False,
|
|
138
|
+
):
|
|
139
|
+
energies = np.asarray(obs_configuration.in_energies)
|
|
140
|
+
parameter_dims = next(iter(parameters.values())).shape
|
|
153
141
|
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
hk.transform(
|
|
157
|
-
lambda par: CountForwardModel(model, instrument, sparse=sparsify_matrix)(par)
|
|
158
|
-
)
|
|
159
|
-
)
|
|
142
|
+
def flux_func(p):
|
|
143
|
+
return model.photon_flux(p, *energies)
|
|
160
144
|
|
|
161
|
-
|
|
162
|
-
|
|
145
|
+
for _ in parameter_dims:
|
|
146
|
+
flux_func = jax.vmap(flux_func)
|
|
163
147
|
|
|
164
|
-
|
|
165
|
-
counts = numpyro.sample(
|
|
166
|
-
f"likelihood_obs_{i}",
|
|
167
|
-
numpyro.distributions.Poisson(obs_model(parameters)),
|
|
168
|
-
)
|
|
148
|
+
flux_func = jax.jit(flux_func)
|
|
169
149
|
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
instrument.exposure,
|
|
175
|
-
grouping=instrument.grouping)
|
|
176
|
-
|
|
177
|
-
observation = Observation(
|
|
178
|
-
pha=pha,
|
|
179
|
-
arf=instrument.arf,
|
|
180
|
-
rmf=instrument.rmf,
|
|
181
|
-
low_energy=instrument.low_energy,
|
|
182
|
-
high_energy=instrument.high_energy
|
|
150
|
+
if sparse:
|
|
151
|
+
# folding.transfer_matrix.data.density > 0.015 is a good criterion to consider sparsify
|
|
152
|
+
transfer_matrix = BCOO.from_scipy_sparse(
|
|
153
|
+
obs_configuration.transfer_matrix.data.to_scipy_sparse().tocsr()
|
|
183
154
|
)
|
|
184
|
-
"""
|
|
185
155
|
|
|
186
|
-
|
|
156
|
+
else:
|
|
157
|
+
transfer_matrix = np.asarray(obs_configuration.transfer_matrix.data.todense())
|
|
187
158
|
|
|
188
|
-
|
|
159
|
+
expected_counts = jnp.matvec(transfer_matrix, flux_func(parameters))
|
|
160
|
+
|
|
161
|
+
# The result is clipped at 1e-6 to avoid 0 round-off and diverging likelihoods
|
|
162
|
+
return jnp.clip(expected_counts, a_min=1e-6)
|
|
189
163
|
|
|
190
164
|
|
|
191
165
|
def fakeit_for_multiple_parameters(
|
|
192
|
-
|
|
166
|
+
obsconfs: ObsConfiguration | list[ObsConfiguration],
|
|
193
167
|
model: SpectralModel,
|
|
194
168
|
parameters: Mapping[K, V],
|
|
195
169
|
rng_key: int = 0,
|
|
@@ -198,11 +172,32 @@ def fakeit_for_multiple_parameters(
|
|
|
198
172
|
):
|
|
199
173
|
"""
|
|
200
174
|
Convenience function to simulate multiple spectra from a given model and a set of parameters.
|
|
175
|
+
This is supposed to be somewhat optimized and can handle multiple parameters at once without blowing
|
|
176
|
+
up the memory. The parameters should be passed as a dictionary with the parameter name as the key and
|
|
177
|
+
the parameter values as the values, the value can be a scalar or a nd-array.
|
|
178
|
+
|
|
179
|
+
# Example:
|
|
180
|
+
|
|
181
|
+
``` python
|
|
182
|
+
from jaxspec.data.util import fakeit_for_multiple_parameters
|
|
183
|
+
from numpy.random import default_rng
|
|
201
184
|
|
|
202
|
-
|
|
185
|
+
rng = default_rng(42)
|
|
186
|
+
size = (10, 30)
|
|
187
|
+
|
|
188
|
+
parameters = {
|
|
189
|
+
"tbabs_1_nh": rng.uniform(0.1, 0.4, size=size),
|
|
190
|
+
"powerlaw_1_alpha": rng.uniform(1, 3, size=size),
|
|
191
|
+
"powerlaw_1_norm": rng.exponential(10 ** (-0.5), size=size),
|
|
192
|
+
"blackbodyrad_1_kT": rng.uniform(0.1, 3.0, size=size),
|
|
193
|
+
"blackbodyrad_1_norm": rng.exponential(10 ** (-3), size=size)
|
|
194
|
+
}
|
|
195
|
+
|
|
196
|
+
spectra = fakeit_for_multiple_parameters(obsconf, model, parameters)
|
|
197
|
+
```
|
|
203
198
|
|
|
204
199
|
Parameters:
|
|
205
|
-
|
|
200
|
+
obsconfs: The observational setup(s).
|
|
206
201
|
model: The model to use.
|
|
207
202
|
parameters: The parameters of the model.
|
|
208
203
|
rng_key: The random number generator seed.
|
|
@@ -210,28 +205,23 @@ def fakeit_for_multiple_parameters(
|
|
|
210
205
|
sparsify_matrix: Whether to sparsify the matrix or not.
|
|
211
206
|
"""
|
|
212
207
|
|
|
213
|
-
|
|
208
|
+
obsconf_list = [obsconfs] if isinstance(obsconfs, ObsConfiguration) else obsconfs
|
|
214
209
|
fakeits = []
|
|
215
210
|
|
|
216
|
-
for i,
|
|
217
|
-
|
|
218
|
-
|
|
211
|
+
for i, obsconf in enumerate(obsconf_list):
|
|
212
|
+
countrate = forward_model_with_multiple_inputs(
|
|
213
|
+
model, parameters, obsconf, sparse=sparsify_matrix
|
|
219
214
|
)
|
|
220
215
|
|
|
221
|
-
@jax.jit
|
|
222
|
-
@jax.vmap
|
|
223
|
-
def obs_model(p):
|
|
224
|
-
return transformed_model.apply(None, p)
|
|
225
|
-
|
|
226
216
|
if apply_stat:
|
|
227
217
|
with handlers.seed(rng_seed=rng_key):
|
|
228
218
|
spectrum = numpyro.sample(
|
|
229
219
|
f"likelihood_obs_{i}",
|
|
230
|
-
numpyro.distributions.Poisson(
|
|
220
|
+
numpyro.distributions.Poisson(countrate),
|
|
231
221
|
)
|
|
232
222
|
|
|
233
223
|
else:
|
|
234
|
-
spectrum =
|
|
224
|
+
spectrum = countrate
|
|
235
225
|
|
|
236
226
|
fakeits.append(spectrum)
|
|
237
227
|
|