jaxspec 0.3.1__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxspec/data/obsconf.py +0 -10
- jaxspec/data/observation.py +5 -11
- jaxspec/fit/_bayesian_model.py +1 -1
- jaxspec/fit/_fitter.py +19 -3
- jaxspec/model/instrument.py +34 -0
- {jaxspec-0.3.1.dist-info → jaxspec-0.3.2.dist-info}/METADATA +2 -2
- {jaxspec-0.3.1.dist-info → jaxspec-0.3.2.dist-info}/RECORD +10 -10
- {jaxspec-0.3.1.dist-info → jaxspec-0.3.2.dist-info}/WHEEL +0 -0
- {jaxspec-0.3.1.dist-info → jaxspec-0.3.2.dist-info}/entry_points.txt +0 -0
- {jaxspec-0.3.1.dist-info → jaxspec-0.3.2.dist-info}/licenses/LICENSE.md +0 -0
jaxspec/data/obsconf.py
CHANGED
|
@@ -163,10 +163,8 @@ class ObsConfiguration(xr.Dataset):
|
|
|
163
163
|
|
|
164
164
|
if observation.folded_background is not None:
|
|
165
165
|
folded_background = observation.folded_background.data[row_idx]
|
|
166
|
-
folded_background_unscaled = observation.folded_background_unscaled.data[row_idx]
|
|
167
166
|
else:
|
|
168
167
|
folded_background = np.zeros_like(folded_counts)
|
|
169
|
-
folded_background_unscaled = np.zeros_like(folded_counts)
|
|
170
168
|
|
|
171
169
|
data_dict = {
|
|
172
170
|
"transfer_matrix": (
|
|
@@ -208,14 +206,6 @@ class ObsConfiguration(xr.Dataset):
|
|
|
208
206
|
"unit": "photons",
|
|
209
207
|
},
|
|
210
208
|
),
|
|
211
|
-
"folded_background_unscaled": (
|
|
212
|
-
["folded_channel"],
|
|
213
|
-
folded_background_unscaled,
|
|
214
|
-
{
|
|
215
|
-
"description": "To be done",
|
|
216
|
-
"unit": "photons",
|
|
217
|
-
},
|
|
218
|
-
),
|
|
219
209
|
}
|
|
220
210
|
|
|
221
211
|
return cls(
|
jaxspec/data/observation.py
CHANGED
|
@@ -46,16 +46,14 @@ class Observation(xr.Dataset):
|
|
|
46
46
|
quality,
|
|
47
47
|
exposure,
|
|
48
48
|
background=None,
|
|
49
|
-
background_unscaled=None,
|
|
50
49
|
backratio=1.0,
|
|
51
50
|
attributes: dict | None = None,
|
|
52
51
|
):
|
|
53
52
|
if attributes is None:
|
|
54
53
|
attributes = {}
|
|
55
54
|
|
|
56
|
-
if background is None
|
|
55
|
+
if background is None:
|
|
57
56
|
background = np.zeros_like(counts, dtype=np.int64)
|
|
58
|
-
background_unscaled = np.zeros_like(counts, dtype=np.int64)
|
|
59
57
|
|
|
60
58
|
data_dict = {
|
|
61
59
|
"counts": (
|
|
@@ -86,7 +84,9 @@ class Observation(xr.Dataset):
|
|
|
86
84
|
),
|
|
87
85
|
"folded_backratio": (
|
|
88
86
|
["folded_channel"],
|
|
89
|
-
np.asarray(
|
|
87
|
+
np.asarray(
|
|
88
|
+
np.ma.filled(grouping @ backratio) / grouping.sum(axis=1).todense(), dtype=float
|
|
89
|
+
),
|
|
90
90
|
{"description": "Background scaling after grouping"},
|
|
91
91
|
),
|
|
92
92
|
"background": (
|
|
@@ -94,11 +94,6 @@ class Observation(xr.Dataset):
|
|
|
94
94
|
np.asarray(background, dtype=np.int64),
|
|
95
95
|
{"description": "Background counts", "unit": "photons"},
|
|
96
96
|
),
|
|
97
|
-
"folded_background_unscaled": (
|
|
98
|
-
["folded_channel"],
|
|
99
|
-
np.asarray(np.ma.filled(grouping @ background_unscaled), dtype=np.int64),
|
|
100
|
-
{"description": "Background counts", "unit": "photons"},
|
|
101
|
-
),
|
|
102
97
|
"folded_background": (
|
|
103
98
|
["folded_channel"],
|
|
104
99
|
np.asarray(np.ma.filled(grouping @ background), dtype=np.float64),
|
|
@@ -147,8 +142,7 @@ class Observation(xr.Dataset):
|
|
|
147
142
|
pha.quality,
|
|
148
143
|
pha.exposure,
|
|
149
144
|
backratio=backratio,
|
|
150
|
-
background=bkg.counts
|
|
151
|
-
background_unscaled=bkg.counts if bkg is not None else None,
|
|
145
|
+
background=bkg.counts if bkg is not None else None,
|
|
152
146
|
attributes=metadata,
|
|
153
147
|
)
|
|
154
148
|
|
jaxspec/fit/_bayesian_model.py
CHANGED
|
@@ -135,7 +135,7 @@ class BayesianModel(nnx.Module):
|
|
|
135
135
|
with numpyro.plate("obs_plate/~/" + name, len(observation.folded_counts)):
|
|
136
136
|
numpyro.sample(
|
|
137
137
|
"obs/~/" + name,
|
|
138
|
-
Poisson(obs_countrate + bkg_countrate
|
|
138
|
+
Poisson(obs_countrate + bkg_countrate * observation.folded_backratio.data),
|
|
139
139
|
obs=observation.folded_counts.data if observed else None,
|
|
140
140
|
)
|
|
141
141
|
|
jaxspec/fit/_fitter.py
CHANGED
|
@@ -215,13 +215,29 @@ class VIFitter(BayesianModelFitter):
|
|
|
215
215
|
self,
|
|
216
216
|
rng_key: int = 0,
|
|
217
217
|
num_steps: int = 10_000,
|
|
218
|
-
optimizer=numpyro.optim.Adam(step_size=0.0005),
|
|
219
|
-
loss=Trace_ELBO(),
|
|
218
|
+
optimizer: numpyro.optim._NumPyroOptim = numpyro.optim.Adam(step_size=0.0005),
|
|
219
|
+
loss: numpyro.infer.elbo.ELBO = Trace_ELBO(),
|
|
220
220
|
num_samples: int = 1000,
|
|
221
|
-
guide=None,
|
|
221
|
+
guide: numpyro.infer.autoguide.AutoGuide | None = None,
|
|
222
222
|
use_transformed_model: bool = True,
|
|
223
223
|
plot_diagnostics: bool = False,
|
|
224
224
|
) -> FitResult:
|
|
225
|
+
"""
|
|
226
|
+
Fit the model to the data using a variational inference approach from numpyro.
|
|
227
|
+
|
|
228
|
+
Parameters:
|
|
229
|
+
rng_key: the random key used to initialize the sampler.
|
|
230
|
+
num_steps: the number of steps for VI.
|
|
231
|
+
optimizer: the optimizer to use.
|
|
232
|
+
num_samples: the number of samples to draw.
|
|
233
|
+
loss: the loss function to use.
|
|
234
|
+
guide: the guide to use.
|
|
235
|
+
use_transformed_model: whether to use the transformed model to build the InferenceData.
|
|
236
|
+
plot_diagnostics: plot the loss during VI.
|
|
237
|
+
|
|
238
|
+
Returns:
|
|
239
|
+
A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
|
|
240
|
+
"""
|
|
225
241
|
bayesian_model = (
|
|
226
242
|
self.transformed_numpyro_model if use_transformed_model else self.numpyro_model
|
|
227
243
|
)
|
jaxspec/model/instrument.py
CHANGED
|
@@ -8,13 +8,26 @@ from numpyro.distributions import Distribution
|
|
|
8
8
|
|
|
9
9
|
|
|
10
10
|
class GainModel(ABC, nnx.Module):
|
|
11
|
+
"""
|
|
12
|
+
Generic class for a gain model
|
|
13
|
+
"""
|
|
14
|
+
|
|
11
15
|
@abstractmethod
|
|
12
16
|
def numpyro_model(self, observation_name: str):
|
|
13
17
|
pass
|
|
14
18
|
|
|
15
19
|
|
|
16
20
|
class ConstantGain(GainModel):
|
|
21
|
+
"""
|
|
22
|
+
A constant gain model
|
|
23
|
+
"""
|
|
24
|
+
|
|
17
25
|
def __init__(self, prior_distribution: Distribution):
|
|
26
|
+
"""
|
|
27
|
+
Parameters:
|
|
28
|
+
prior_distribution: the prior distribution for the gain value.
|
|
29
|
+
"""
|
|
30
|
+
|
|
18
31
|
self.prior_distribution = prior_distribution
|
|
19
32
|
|
|
20
33
|
def numpyro_model(self, observation_name: str):
|
|
@@ -27,13 +40,25 @@ class ConstantGain(GainModel):
|
|
|
27
40
|
|
|
28
41
|
|
|
29
42
|
class ShiftModel(ABC, nnx.Module):
|
|
43
|
+
"""
|
|
44
|
+
Generic class for a shift model
|
|
45
|
+
"""
|
|
46
|
+
|
|
30
47
|
@abstractmethod
|
|
31
48
|
def numpyro_model(self, observation_name: str):
|
|
32
49
|
pass
|
|
33
50
|
|
|
34
51
|
|
|
35
52
|
class ConstantShift(ShiftModel):
|
|
53
|
+
"""
|
|
54
|
+
A constant shift model
|
|
55
|
+
"""
|
|
56
|
+
|
|
36
57
|
def __init__(self, prior_distribution: Distribution):
|
|
58
|
+
"""
|
|
59
|
+
Parameters:
|
|
60
|
+
prior_distribution: the prior distribution for the shift value.
|
|
61
|
+
"""
|
|
37
62
|
self.prior_distribution = prior_distribution
|
|
38
63
|
|
|
39
64
|
def numpyro_model(self, observation_name: str):
|
|
@@ -52,6 +77,15 @@ class InstrumentModel(nnx.Module):
|
|
|
52
77
|
gain_model: GainModel | None = None,
|
|
53
78
|
shift_model: ShiftModel | None = None,
|
|
54
79
|
):
|
|
80
|
+
"""
|
|
81
|
+
Encapsulate an instrument model, build as a combination of a shift and gain model.
|
|
82
|
+
|
|
83
|
+
Parameters:
|
|
84
|
+
reference_observation_name : The observation to use as a reference
|
|
85
|
+
gain_model : The gain model
|
|
86
|
+
shift_model : The shift model
|
|
87
|
+
"""
|
|
88
|
+
|
|
55
89
|
self.reference = reference_observation_name
|
|
56
90
|
self.gain_model = gain_model
|
|
57
91
|
self.shift_model = shift_model
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: jaxspec
|
|
3
|
-
Version: 0.3.
|
|
3
|
+
Version: 0.3.2
|
|
4
4
|
Summary: jaxspec is a bayesian spectral fitting library for X-ray astronomy.
|
|
5
5
|
Project-URL: Homepage, https://github.com/renecotyfanboy/jaxspec
|
|
6
6
|
Project-URL: Documentation, https://jaxspec.readthedocs.io/en/latest/
|
|
@@ -23,7 +23,7 @@ Requires-Dist: mendeleev<1.2,>=0.15
|
|
|
23
23
|
Requires-Dist: networkx~=3.1
|
|
24
24
|
Requires-Dist: numpy<3.0.0
|
|
25
25
|
Requires-Dist: numpyro<0.20,>=0.17.0
|
|
26
|
-
Requires-Dist: optimistix<0.0.
|
|
26
|
+
Requires-Dist: optimistix<0.0.12,>=0.0.10
|
|
27
27
|
Requires-Dist: pandas<3,>=2.2.0
|
|
28
28
|
Requires-Dist: pooch<2,>=1.8.2
|
|
29
29
|
Requires-Dist: scipy<1.16
|
|
@@ -5,8 +5,8 @@ jaxspec/analysis/compare.py,sha256=g2UFhmR9Zt-7cz5gQFOB6lXuklXB3yTyUvjTypOzoSY,7
|
|
|
5
5
|
jaxspec/analysis/results.py,sha256=tIBWmLoX43EY2BXt50ec8A-DqQ98PMd3m-FqTRT4iRE,26073
|
|
6
6
|
jaxspec/data/__init__.py,sha256=aantcYKC9kZFvaE-V2SIwSuLhIld17Kjrd9CIUu___Y,415
|
|
7
7
|
jaxspec/data/instrument.py,sha256=RDiG_LkucvnF2XE_ghTFME6d_2YirgQUcEY0gEle6dk,4775
|
|
8
|
-
jaxspec/data/obsconf.py,sha256=
|
|
9
|
-
jaxspec/data/observation.py,sha256=
|
|
8
|
+
jaxspec/data/obsconf.py,sha256=bkYuD6mJgj8QmRaDVhcnXwUukVdo20xllzaI57prHag,10056
|
|
9
|
+
jaxspec/data/observation.py,sha256=7FHJm1jHEEFyrqxg3COsGmfdh5dg-5XnfKCp1yb5fNY,7411
|
|
10
10
|
jaxspec/data/ogip.py,sha256=eMmBuROW4eMRxRHkPPyGHf933e0IcREqB8WMQFMS2lY,9810
|
|
11
11
|
jaxspec/data/util.py,sha256=4_f6ByGjUEZXTwrB37dCyYaTB1pjF10Z0ho7-4GrQuc,9761
|
|
12
12
|
jaxspec/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
@@ -16,15 +16,15 @@ jaxspec/experimental/intrument_models.py,sha256=vuRw7xypPI9YV-Hv8chVNP4ti24dCGjb
|
|
|
16
16
|
jaxspec/experimental/nested_sampler.py,sha256=8jCAXQAe2mD5YSNSF0jia_rFWES_MzwRM3FrQQS_x7w,2807
|
|
17
17
|
jaxspec/experimental/tabulated.py,sha256=H0llUiso2KGH4xUzTUSVPy-6I8D3wm707lU_Z1P5uq4,9429
|
|
18
18
|
jaxspec/fit/__init__.py,sha256=OaS0-Hkb3Hd-AkE2o-KWfoWMX0NSCPY-_FP2znHf9l0,153
|
|
19
|
-
jaxspec/fit/_bayesian_model.py,sha256=
|
|
19
|
+
jaxspec/fit/_bayesian_model.py,sha256=7c2Twgz06QV1S9DdctdVk5YT1v7P-ln100bWXAvv7Go,15179
|
|
20
20
|
jaxspec/fit/_build_model.py,sha256=pNZVuVfwOq3Pg23opH7xRv28DsSkQZpvy2Z-1hQSfNs,3219
|
|
21
|
-
jaxspec/fit/_fitter.py,sha256=
|
|
21
|
+
jaxspec/fit/_fitter.py,sha256=92gd1P7CNIqusGb64x_DpBcb0KcoGyfvSDiEnRbfqP0,9709
|
|
22
22
|
jaxspec/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
23
23
|
jaxspec/model/_graph_util.py,sha256=hPvHYmAxb7P3nyIecaZ7RqWOjwcZ1WvUByt_yNANiaY,4552
|
|
24
24
|
jaxspec/model/abc.py,sha256=RGrqDrXVNjCy7GYBZL-l1PZ3Lpr37SsMIw7L9_B8WJ4,14773
|
|
25
25
|
jaxspec/model/additive.py,sha256=rEONSy7b7lwfXIhuPqtI4y2Yhq55EqrlEtEckEe6TA0,20538
|
|
26
26
|
jaxspec/model/background.py,sha256=VLSrU0YCW9GSHCtaEdcth-sp74aPyEVSizIMFkTpM7M,7759
|
|
27
|
-
jaxspec/model/instrument.py,sha256=
|
|
27
|
+
jaxspec/model/instrument.py,sha256=1zLZgHmBZs8RLKTMT3Wu4bCx6JnxBUjhRIpYG2rLaZM,2947
|
|
28
28
|
jaxspec/model/list.py,sha256=uC9rLEEeph10q6shat86WLACVuTSx73RGMl8Ij0jqQY,875
|
|
29
29
|
jaxspec/model/multiplicative.py,sha256=odaOlF0K1KjzUDstPtWAk95ScHoZ7_XveOew6l3tbeU,8337
|
|
30
30
|
jaxspec/scripts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
@@ -35,8 +35,8 @@ jaxspec/util/integrate.py,sha256=7GwBSagmDzsF3P53tPs-oakeq0zHEwmZZS2zQlXngbE,463
|
|
|
35
35
|
jaxspec/util/misc.py,sha256=O3qorCL1Y2X1BS2jdd36C1eDHK9QDXTSOr9kj3uqcJo,654
|
|
36
36
|
jaxspec/util/online_storage.py,sha256=vm56RfcbFKpkRVfr0bXO7J9aQxuBq-I_oEgA26YIhCo,2469
|
|
37
37
|
jaxspec/util/typing.py,sha256=ZQM_l68qyYnIBZPz_1mKvwPMx64jvVBD8Uj6bx9sHv0,140
|
|
38
|
-
jaxspec-0.3.
|
|
39
|
-
jaxspec-0.3.
|
|
40
|
-
jaxspec-0.3.
|
|
41
|
-
jaxspec-0.3.
|
|
42
|
-
jaxspec-0.3.
|
|
38
|
+
jaxspec-0.3.2.dist-info/METADATA,sha256=10PjN7QwhbU8BoZc9f1Lga2n1u1_j4p8Lk2Syy6cJC8,4199
|
|
39
|
+
jaxspec-0.3.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
40
|
+
jaxspec-0.3.2.dist-info/entry_points.txt,sha256=4ffU5AImfcEBxgWTqopQll2YffpFldOswXRh16pd0Dc,72
|
|
41
|
+
jaxspec-0.3.2.dist-info/licenses/LICENSE.md,sha256=2q5XoWzddts5IqzIcgYYMOL21puU3MfO8gvT3Ype1eQ,1073
|
|
42
|
+
jaxspec-0.3.2.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|