jaxspec 0.3.1__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxspec/data/obsconf.py CHANGED
@@ -163,10 +163,8 @@ class ObsConfiguration(xr.Dataset):
163
163
 
164
164
  if observation.folded_background is not None:
165
165
  folded_background = observation.folded_background.data[row_idx]
166
- folded_background_unscaled = observation.folded_background_unscaled.data[row_idx]
167
166
  else:
168
167
  folded_background = np.zeros_like(folded_counts)
169
- folded_background_unscaled = np.zeros_like(folded_counts)
170
168
 
171
169
  data_dict = {
172
170
  "transfer_matrix": (
@@ -208,14 +206,6 @@ class ObsConfiguration(xr.Dataset):
208
206
  "unit": "photons",
209
207
  },
210
208
  ),
211
- "folded_background_unscaled": (
212
- ["folded_channel"],
213
- folded_background_unscaled,
214
- {
215
- "description": "To be done",
216
- "unit": "photons",
217
- },
218
- ),
219
209
  }
220
210
 
221
211
  return cls(
@@ -46,16 +46,14 @@ class Observation(xr.Dataset):
46
46
  quality,
47
47
  exposure,
48
48
  background=None,
49
- background_unscaled=None,
50
49
  backratio=1.0,
51
50
  attributes: dict | None = None,
52
51
  ):
53
52
  if attributes is None:
54
53
  attributes = {}
55
54
 
56
- if background is None or background_unscaled is None:
55
+ if background is None:
57
56
  background = np.zeros_like(counts, dtype=np.int64)
58
- background_unscaled = np.zeros_like(counts, dtype=np.int64)
59
57
 
60
58
  data_dict = {
61
59
  "counts": (
@@ -86,7 +84,9 @@ class Observation(xr.Dataset):
86
84
  ),
87
85
  "folded_backratio": (
88
86
  ["folded_channel"],
89
- np.asarray(np.ma.filled(grouping @ backratio), dtype=float),
87
+ np.asarray(
88
+ np.ma.filled(grouping @ backratio) / grouping.sum(axis=1).todense(), dtype=float
89
+ ),
90
90
  {"description": "Background scaling after grouping"},
91
91
  ),
92
92
  "background": (
@@ -94,11 +94,6 @@ class Observation(xr.Dataset):
94
94
  np.asarray(background, dtype=np.int64),
95
95
  {"description": "Background counts", "unit": "photons"},
96
96
  ),
97
- "folded_background_unscaled": (
98
- ["folded_channel"],
99
- np.asarray(np.ma.filled(grouping @ background_unscaled), dtype=np.int64),
100
- {"description": "Background counts", "unit": "photons"},
101
- ),
102
97
  "folded_background": (
103
98
  ["folded_channel"],
104
99
  np.asarray(np.ma.filled(grouping @ background), dtype=np.float64),
@@ -147,8 +142,7 @@ class Observation(xr.Dataset):
147
142
  pha.quality,
148
143
  pha.exposure,
149
144
  backratio=backratio,
150
- background=bkg.counts * backratio if bkg is not None else None,
151
- background_unscaled=bkg.counts if bkg is not None else None,
145
+ background=bkg.counts if bkg is not None else None,
152
146
  attributes=metadata,
153
147
  )
154
148
 
@@ -135,7 +135,7 @@ class BayesianModel(nnx.Module):
135
135
  with numpyro.plate("obs_plate/~/" + name, len(observation.folded_counts)):
136
136
  numpyro.sample(
137
137
  "obs/~/" + name,
138
- Poisson(obs_countrate + bkg_countrate / observation.folded_backratio.data),
138
+ Poisson(obs_countrate + bkg_countrate * observation.folded_backratio.data),
139
139
  obs=observation.folded_counts.data if observed else None,
140
140
  )
141
141
 
jaxspec/fit/_fitter.py CHANGED
@@ -215,13 +215,29 @@ class VIFitter(BayesianModelFitter):
215
215
  self,
216
216
  rng_key: int = 0,
217
217
  num_steps: int = 10_000,
218
- optimizer=numpyro.optim.Adam(step_size=0.0005),
219
- loss=Trace_ELBO(),
218
+ optimizer: numpyro.optim._NumPyroOptim = numpyro.optim.Adam(step_size=0.0005),
219
+ loss: numpyro.infer.elbo.ELBO = Trace_ELBO(),
220
220
  num_samples: int = 1000,
221
- guide=None,
221
+ guide: numpyro.infer.autoguide.AutoGuide | None = None,
222
222
  use_transformed_model: bool = True,
223
223
  plot_diagnostics: bool = False,
224
224
  ) -> FitResult:
225
+ """
226
+ Fit the model to the data using a variational inference approach from numpyro.
227
+
228
+ Parameters:
229
+ rng_key: the random key used to initialize the sampler.
230
+ num_steps: the number of steps for VI.
231
+ optimizer: the optimizer to use.
232
+ num_samples: the number of samples to draw.
233
+ loss: the loss function to use.
234
+ guide: the guide to use.
235
+ use_transformed_model: whether to use the transformed model to build the InferenceData.
236
+ plot_diagnostics: plot the loss during VI.
237
+
238
+ Returns:
239
+ A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
240
+ """
225
241
  bayesian_model = (
226
242
  self.transformed_numpyro_model if use_transformed_model else self.numpyro_model
227
243
  )
@@ -8,13 +8,26 @@ from numpyro.distributions import Distribution
8
8
 
9
9
 
10
10
  class GainModel(ABC, nnx.Module):
11
+ """
12
+ Generic class for a gain model
13
+ """
14
+
11
15
  @abstractmethod
12
16
  def numpyro_model(self, observation_name: str):
13
17
  pass
14
18
 
15
19
 
16
20
  class ConstantGain(GainModel):
21
+ """
22
+ A constant gain model
23
+ """
24
+
17
25
  def __init__(self, prior_distribution: Distribution):
26
+ """
27
+ Parameters:
28
+ prior_distribution: the prior distribution for the gain value.
29
+ """
30
+
18
31
  self.prior_distribution = prior_distribution
19
32
 
20
33
  def numpyro_model(self, observation_name: str):
@@ -27,13 +40,25 @@ class ConstantGain(GainModel):
27
40
 
28
41
 
29
42
  class ShiftModel(ABC, nnx.Module):
43
+ """
44
+ Generic class for a shift model
45
+ """
46
+
30
47
  @abstractmethod
31
48
  def numpyro_model(self, observation_name: str):
32
49
  pass
33
50
 
34
51
 
35
52
  class ConstantShift(ShiftModel):
53
+ """
54
+ A constant shift model
55
+ """
56
+
36
57
  def __init__(self, prior_distribution: Distribution):
58
+ """
59
+ Parameters:
60
+ prior_distribution: the prior distribution for the shift value.
61
+ """
37
62
  self.prior_distribution = prior_distribution
38
63
 
39
64
  def numpyro_model(self, observation_name: str):
@@ -52,6 +77,15 @@ class InstrumentModel(nnx.Module):
52
77
  gain_model: GainModel | None = None,
53
78
  shift_model: ShiftModel | None = None,
54
79
  ):
80
+ """
81
+ Encapsulate an instrument model, build as a combination of a shift and gain model.
82
+
83
+ Parameters:
84
+ reference_observation_name : The observation to use as a reference
85
+ gain_model : The gain model
86
+ shift_model : The shift model
87
+ """
88
+
55
89
  self.reference = reference_observation_name
56
90
  self.gain_model = gain_model
57
91
  self.shift_model = shift_model
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: jaxspec
3
- Version: 0.3.1
3
+ Version: 0.3.2
4
4
  Summary: jaxspec is a bayesian spectral fitting library for X-ray astronomy.
5
5
  Project-URL: Homepage, https://github.com/renecotyfanboy/jaxspec
6
6
  Project-URL: Documentation, https://jaxspec.readthedocs.io/en/latest/
@@ -23,7 +23,7 @@ Requires-Dist: mendeleev<1.2,>=0.15
23
23
  Requires-Dist: networkx~=3.1
24
24
  Requires-Dist: numpy<3.0.0
25
25
  Requires-Dist: numpyro<0.20,>=0.17.0
26
- Requires-Dist: optimistix<0.0.11,>=0.0.10
26
+ Requires-Dist: optimistix<0.0.12,>=0.0.10
27
27
  Requires-Dist: pandas<3,>=2.2.0
28
28
  Requires-Dist: pooch<2,>=1.8.2
29
29
  Requires-Dist: scipy<1.16
@@ -5,8 +5,8 @@ jaxspec/analysis/compare.py,sha256=g2UFhmR9Zt-7cz5gQFOB6lXuklXB3yTyUvjTypOzoSY,7
5
5
  jaxspec/analysis/results.py,sha256=tIBWmLoX43EY2BXt50ec8A-DqQ98PMd3m-FqTRT4iRE,26073
6
6
  jaxspec/data/__init__.py,sha256=aantcYKC9kZFvaE-V2SIwSuLhIld17Kjrd9CIUu___Y,415
7
7
  jaxspec/data/instrument.py,sha256=RDiG_LkucvnF2XE_ghTFME6d_2YirgQUcEY0gEle6dk,4775
8
- jaxspec/data/obsconf.py,sha256=G0RwNshvbDQzw_ba8Y8NdI-cRsgEj-OlSNdeYCANqVM,10484
9
- jaxspec/data/observation.py,sha256=OHXfs7ApC8JAiG6h1teEmXu3iaNyhrqI5BCwl-qFvoM,7820
8
+ jaxspec/data/obsconf.py,sha256=bkYuD6mJgj8QmRaDVhcnXwUukVdo20xllzaI57prHag,10056
9
+ jaxspec/data/observation.py,sha256=7FHJm1jHEEFyrqxg3COsGmfdh5dg-5XnfKCp1yb5fNY,7411
10
10
  jaxspec/data/ogip.py,sha256=eMmBuROW4eMRxRHkPPyGHf933e0IcREqB8WMQFMS2lY,9810
11
11
  jaxspec/data/util.py,sha256=4_f6ByGjUEZXTwrB37dCyYaTB1pjF10Z0ho7-4GrQuc,9761
12
12
  jaxspec/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -16,15 +16,15 @@ jaxspec/experimental/intrument_models.py,sha256=vuRw7xypPI9YV-Hv8chVNP4ti24dCGjb
16
16
  jaxspec/experimental/nested_sampler.py,sha256=8jCAXQAe2mD5YSNSF0jia_rFWES_MzwRM3FrQQS_x7w,2807
17
17
  jaxspec/experimental/tabulated.py,sha256=H0llUiso2KGH4xUzTUSVPy-6I8D3wm707lU_Z1P5uq4,9429
18
18
  jaxspec/fit/__init__.py,sha256=OaS0-Hkb3Hd-AkE2o-KWfoWMX0NSCPY-_FP2znHf9l0,153
19
- jaxspec/fit/_bayesian_model.py,sha256=jSCzAzoAhsmUX7mKUikbUR9A1ZNIaY6rdPOxq6OZSU0,15179
19
+ jaxspec/fit/_bayesian_model.py,sha256=7c2Twgz06QV1S9DdctdVk5YT1v7P-ln100bWXAvv7Go,15179
20
20
  jaxspec/fit/_build_model.py,sha256=pNZVuVfwOq3Pg23opH7xRv28DsSkQZpvy2Z-1hQSfNs,3219
21
- jaxspec/fit/_fitter.py,sha256=doBTJqTP5CN1OJhZHVlS3oMVOzPJyH4YqOnGevIIU68,8893
21
+ jaxspec/fit/_fitter.py,sha256=92gd1P7CNIqusGb64x_DpBcb0KcoGyfvSDiEnRbfqP0,9709
22
22
  jaxspec/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
23
23
  jaxspec/model/_graph_util.py,sha256=hPvHYmAxb7P3nyIecaZ7RqWOjwcZ1WvUByt_yNANiaY,4552
24
24
  jaxspec/model/abc.py,sha256=RGrqDrXVNjCy7GYBZL-l1PZ3Lpr37SsMIw7L9_B8WJ4,14773
25
25
  jaxspec/model/additive.py,sha256=rEONSy7b7lwfXIhuPqtI4y2Yhq55EqrlEtEckEe6TA0,20538
26
26
  jaxspec/model/background.py,sha256=VLSrU0YCW9GSHCtaEdcth-sp74aPyEVSizIMFkTpM7M,7759
27
- jaxspec/model/instrument.py,sha256=FSGgWCQjUfSVOzJ3YfrbQQ2abPlZ4P4ndtLL9Axcl-g,2217
27
+ jaxspec/model/instrument.py,sha256=1zLZgHmBZs8RLKTMT3Wu4bCx6JnxBUjhRIpYG2rLaZM,2947
28
28
  jaxspec/model/list.py,sha256=uC9rLEEeph10q6shat86WLACVuTSx73RGMl8Ij0jqQY,875
29
29
  jaxspec/model/multiplicative.py,sha256=odaOlF0K1KjzUDstPtWAk95ScHoZ7_XveOew6l3tbeU,8337
30
30
  jaxspec/scripts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -35,8 +35,8 @@ jaxspec/util/integrate.py,sha256=7GwBSagmDzsF3P53tPs-oakeq0zHEwmZZS2zQlXngbE,463
35
35
  jaxspec/util/misc.py,sha256=O3qorCL1Y2X1BS2jdd36C1eDHK9QDXTSOr9kj3uqcJo,654
36
36
  jaxspec/util/online_storage.py,sha256=vm56RfcbFKpkRVfr0bXO7J9aQxuBq-I_oEgA26YIhCo,2469
37
37
  jaxspec/util/typing.py,sha256=ZQM_l68qyYnIBZPz_1mKvwPMx64jvVBD8Uj6bx9sHv0,140
38
- jaxspec-0.3.1.dist-info/METADATA,sha256=8i1cuzZY4iwIjWEhIPXBiC-Z8Y4Vv27omIMwTKnoPwo,4199
39
- jaxspec-0.3.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
40
- jaxspec-0.3.1.dist-info/entry_points.txt,sha256=4ffU5AImfcEBxgWTqopQll2YffpFldOswXRh16pd0Dc,72
41
- jaxspec-0.3.1.dist-info/licenses/LICENSE.md,sha256=2q5XoWzddts5IqzIcgYYMOL21puU3MfO8gvT3Ype1eQ,1073
42
- jaxspec-0.3.1.dist-info/RECORD,,
38
+ jaxspec-0.3.2.dist-info/METADATA,sha256=10PjN7QwhbU8BoZc9f1Lga2n1u1_j4p8Lk2Syy6cJC8,4199
39
+ jaxspec-0.3.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
40
+ jaxspec-0.3.2.dist-info/entry_points.txt,sha256=4ffU5AImfcEBxgWTqopQll2YffpFldOswXRh16pd0Dc,72
41
+ jaxspec-0.3.2.dist-info/licenses/LICENSE.md,sha256=2q5XoWzddts5IqzIcgYYMOL21puU3MfO8gvT3Ype1eQ,1073
42
+ jaxspec-0.3.2.dist-info/RECORD,,