jaxspec 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -122,7 +122,9 @@ class FitResult:
122
122
 
123
123
  samples_shape = (len(posterior.coords["chain"]), len(posterior.coords["draw"]))
124
124
 
125
- total_shape = tuple(posterior.sizes[d] for d in posterior.coords)
125
+ total_shape = tuple(
126
+ posterior.sizes[d] for d in posterior.coords if not (("obs" in d) or ("bkg" in d))
127
+ )
126
128
 
127
129
  posterior = {key: posterior[key].data for key in posterior.data_vars}
128
130
 
jaxspec/data/obsconf.py CHANGED
@@ -163,10 +163,8 @@ class ObsConfiguration(xr.Dataset):
163
163
 
164
164
  if observation.folded_background is not None:
165
165
  folded_background = observation.folded_background.data[row_idx]
166
- folded_background_unscaled = observation.folded_background_unscaled.data[row_idx]
167
166
  else:
168
167
  folded_background = np.zeros_like(folded_counts)
169
- folded_background_unscaled = np.zeros_like(folded_counts)
170
168
 
171
169
  data_dict = {
172
170
  "transfer_matrix": (
@@ -208,14 +206,6 @@ class ObsConfiguration(xr.Dataset):
208
206
  "unit": "photons",
209
207
  },
210
208
  ),
211
- "folded_background_unscaled": (
212
- ["folded_channel"],
213
- folded_background_unscaled,
214
- {
215
- "description": "To be done",
216
- "unit": "photons",
217
- },
218
- ),
219
209
  }
220
210
 
221
211
  return cls(
@@ -46,16 +46,14 @@ class Observation(xr.Dataset):
46
46
  quality,
47
47
  exposure,
48
48
  background=None,
49
- background_unscaled=None,
50
49
  backratio=1.0,
51
50
  attributes: dict | None = None,
52
51
  ):
53
52
  if attributes is None:
54
53
  attributes = {}
55
54
 
56
- if background is None or background_unscaled is None:
55
+ if background is None:
57
56
  background = np.zeros_like(counts, dtype=np.int64)
58
- background_unscaled = np.zeros_like(counts, dtype=np.int64)
59
57
 
60
58
  data_dict = {
61
59
  "counts": (
@@ -86,7 +84,9 @@ class Observation(xr.Dataset):
86
84
  ),
87
85
  "folded_backratio": (
88
86
  ["folded_channel"],
89
- np.asarray(np.ma.filled(grouping @ backratio), dtype=float),
87
+ np.asarray(
88
+ np.ma.filled(grouping @ backratio) / grouping.sum(axis=1).todense(), dtype=float
89
+ ),
90
90
  {"description": "Background scaling after grouping"},
91
91
  ),
92
92
  "background": (
@@ -94,11 +94,6 @@ class Observation(xr.Dataset):
94
94
  np.asarray(background, dtype=np.int64),
95
95
  {"description": "Background counts", "unit": "photons"},
96
96
  ),
97
- "folded_background_unscaled": (
98
- ["folded_channel"],
99
- np.asarray(np.ma.filled(grouping @ background_unscaled), dtype=np.int64),
100
- {"description": "Background counts", "unit": "photons"},
101
- ),
102
97
  "folded_background": (
103
98
  ["folded_channel"],
104
99
  np.asarray(np.ma.filled(grouping @ background), dtype=np.float64),
@@ -147,8 +142,7 @@ class Observation(xr.Dataset):
147
142
  pha.quality,
148
143
  pha.exposure,
149
144
  backratio=backratio,
150
- background=bkg.counts * backratio if bkg is not None else None,
151
- background_unscaled=bkg.counts if bkg is not None else None,
145
+ background=bkg.counts if bkg is not None else None,
152
146
  attributes=metadata,
153
147
  )
154
148
 
@@ -10,9 +10,8 @@ import matplotlib.pyplot as plt
10
10
  import numpyro
11
11
 
12
12
  from flax import nnx
13
- from jax.experimental import mesh_utils
14
13
  from jax.random import PRNGKey
15
- from jax.sharding import PositionalSharding
14
+ from jax.sharding import NamedSharding, PartitionSpec
16
15
  from numpyro.distributions import Poisson, TransformedDistribution
17
16
  from numpyro.infer import Predictive
18
17
  from numpyro.infer.inspect import get_model_relations
@@ -136,7 +135,7 @@ class BayesianModel(nnx.Module):
136
135
  with numpyro.plate("obs_plate/~/" + name, len(observation.folded_counts)):
137
136
  numpyro.sample(
138
137
  "obs/~/" + name,
139
- Poisson(obs_countrate + bkg_countrate / observation.folded_backratio.data),
138
+ Poisson(obs_countrate + bkg_countrate * observation.folded_backratio.data),
140
139
  obs=observation.folded_counts.data if observed else None,
141
140
  )
142
141
 
@@ -244,7 +243,7 @@ class BayesianModel(nnx.Module):
244
243
  return log_posterior_prob
245
244
 
246
245
  @cached_property
247
- def _parameter_names(self) -> list[str]:
246
+ def parameter_names(self) -> list[str]:
248
247
  """
249
248
  A list of parameter names for the model.
250
249
  """
@@ -269,7 +268,7 @@ class BayesianModel(nnx.Module):
269
268
  """
270
269
  input_params = {}
271
270
 
272
- for index, key in enumerate(self._parameter_names):
271
+ for index, key in enumerate(self.parameter_names):
273
272
  input_params[key] = theta[index]
274
273
 
275
274
  return input_params
@@ -279,9 +278,9 @@ class BayesianModel(nnx.Module):
279
278
  Convert a dictionary of parameters to an array of parameters.
280
279
  """
281
280
 
282
- theta = jnp.zeros(len(self._parameter_names))
281
+ theta = jnp.zeros(len(self.parameter_names))
283
282
 
284
- for index, key in enumerate(self._parameter_names):
283
+ for index, key in enumerate(self.parameter_names):
285
284
  theta = theta.at[index].set(dict_of_params[key])
286
285
 
287
286
  return theta
@@ -298,7 +297,7 @@ class BayesianModel(nnx.Module):
298
297
  @jax.jit
299
298
  def prior_sample(key):
300
299
  return Predictive(
301
- self.numpyro_model, return_sites=self._parameter_names, num_samples=num_samples
300
+ self.numpyro_model, return_sites=self.parameter_names, num_samples=num_samples
302
301
  )(key, observed=False)
303
302
 
304
303
  return prior_sample(key)
@@ -324,7 +323,8 @@ class BayesianModel(nnx.Module):
324
323
  """
325
324
  key_prior, key_posterior = jax.random.split(key, 2)
326
325
  n_devices = len(jax.local_devices())
327
- sharding = PositionalSharding(mesh_utils.create_device_mesh((n_devices,)))
326
+ mesh = jax.make_mesh((n_devices,), ("batch",))
327
+ sharding = NamedSharding(mesh, PartitionSpec("batch"))
328
328
 
329
329
  # Sample from prior and correct if the number of samples is not a multiple of the number of devices
330
330
  if num_samples % n_devices != 0:
jaxspec/fit/_fitter.py CHANGED
@@ -215,13 +215,29 @@ class VIFitter(BayesianModelFitter):
215
215
  self,
216
216
  rng_key: int = 0,
217
217
  num_steps: int = 10_000,
218
- optimizer=numpyro.optim.Adam(step_size=0.0005),
219
- loss=Trace_ELBO(),
218
+ optimizer: numpyro.optim._NumPyroOptim = numpyro.optim.Adam(step_size=0.0005),
219
+ loss: numpyro.infer.elbo.ELBO = Trace_ELBO(),
220
220
  num_samples: int = 1000,
221
- guide=None,
221
+ guide: numpyro.infer.autoguide.AutoGuide | None = None,
222
222
  use_transformed_model: bool = True,
223
223
  plot_diagnostics: bool = False,
224
224
  ) -> FitResult:
225
+ """
226
+ Fit the model to the data using a variational inference approach from numpyro.
227
+
228
+ Parameters:
229
+ rng_key: the random key used to initialize the sampler.
230
+ num_steps: the number of steps for VI.
231
+ optimizer: the optimizer to use.
232
+ num_samples: the number of samples to draw.
233
+ loss: the loss function to use.
234
+ guide: the guide to use.
235
+ use_transformed_model: whether to use the transformed model to build the InferenceData.
236
+ plot_diagnostics: plot the loss during VI.
237
+
238
+ Returns:
239
+ A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
240
+ """
225
241
  bayesian_model = (
226
242
  self.transformed_numpyro_model if use_transformed_model else self.numpyro_model
227
243
  )
@@ -8,13 +8,26 @@ from numpyro.distributions import Distribution
8
8
 
9
9
 
10
10
  class GainModel(ABC, nnx.Module):
11
+ """
12
+ Generic class for a gain model
13
+ """
14
+
11
15
  @abstractmethod
12
16
  def numpyro_model(self, observation_name: str):
13
17
  pass
14
18
 
15
19
 
16
20
  class ConstantGain(GainModel):
21
+ """
22
+ A constant gain model
23
+ """
24
+
17
25
  def __init__(self, prior_distribution: Distribution):
26
+ """
27
+ Parameters:
28
+ prior_distribution: the prior distribution for the gain value.
29
+ """
30
+
18
31
  self.prior_distribution = prior_distribution
19
32
 
20
33
  def numpyro_model(self, observation_name: str):
@@ -27,13 +40,25 @@ class ConstantGain(GainModel):
27
40
 
28
41
 
29
42
  class ShiftModel(ABC, nnx.Module):
43
+ """
44
+ Generic class for a shift model
45
+ """
46
+
30
47
  @abstractmethod
31
48
  def numpyro_model(self, observation_name: str):
32
49
  pass
33
50
 
34
51
 
35
52
  class ConstantShift(ShiftModel):
53
+ """
54
+ A constant shift model
55
+ """
56
+
36
57
  def __init__(self, prior_distribution: Distribution):
58
+ """
59
+ Parameters:
60
+ prior_distribution: the prior distribution for the shift value.
61
+ """
37
62
  self.prior_distribution = prior_distribution
38
63
 
39
64
  def numpyro_model(self, observation_name: str):
@@ -52,6 +77,15 @@ class InstrumentModel(nnx.Module):
52
77
  gain_model: GainModel | None = None,
53
78
  shift_model: ShiftModel | None = None,
54
79
  ):
80
+ """
81
+ Encapsulate an instrument model, build as a combination of a shift and gain model.
82
+
83
+ Parameters:
84
+ reference_observation_name : The observation to use as a reference
85
+ gain_model : The gain model
86
+ shift_model : The shift model
87
+ """
88
+
55
89
  self.reference = reference_observation_name
56
90
  self.gain_model = gain_model
57
91
  self.shift_model = shift_model
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: jaxspec
3
- Version: 0.3.0
3
+ Version: 0.3.2
4
4
  Summary: jaxspec is a bayesian spectral fitting library for X-ray astronomy.
5
5
  Project-URL: Homepage, https://github.com/renecotyfanboy/jaxspec
6
6
  Project-URL: Documentation, https://jaxspec.readthedocs.io/en/latest/
@@ -15,18 +15,18 @@ Requires-Dist: chainconsumer<2,>=1.1.2
15
15
  Requires-Dist: cmasher<2,>=1.6.3
16
16
  Requires-Dist: flax>0.10.5
17
17
  Requires-Dist: interpax<0.4,>=0.3.5
18
- Requires-Dist: jax<0.6,>=0.5.0
18
+ Requires-Dist: jax<0.7,>=0.5.0
19
19
  Requires-Dist: jaxns<3,>=2.6.7
20
20
  Requires-Dist: jaxopt<0.9,>=0.8.3
21
21
  Requires-Dist: matplotlib<4,>=3.8.0
22
22
  Requires-Dist: mendeleev<1.2,>=0.15
23
23
  Requires-Dist: networkx~=3.1
24
24
  Requires-Dist: numpy<3.0.0
25
- Requires-Dist: numpyro<0.19,>=0.17.0
26
- Requires-Dist: optimistix<0.0.11,>=0.0.10
25
+ Requires-Dist: numpyro<0.20,>=0.17.0
26
+ Requires-Dist: optimistix<0.0.12,>=0.0.10
27
27
  Requires-Dist: pandas<3,>=2.2.0
28
28
  Requires-Dist: pooch<2,>=1.8.2
29
- Requires-Dist: scipy<1.15
29
+ Requires-Dist: scipy<1.16
30
30
  Requires-Dist: seaborn<0.14,>=0.13.1
31
31
  Requires-Dist: simpleeval<1.1.0,>=0.9.13
32
32
  Requires-Dist: sparse>0.15
@@ -2,11 +2,11 @@ jaxspec/__init__.py,sha256=Sbn02lX6Y-zNXk17N8dec22c5jeypiS0LkHmGfz7lWA,126
2
2
  jaxspec/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  jaxspec/analysis/_plot.py,sha256=0xEz-e_xk7XvU6PUfbNwxaWg1-SxAF2XAqhkxWEhIFs,6239
4
4
  jaxspec/analysis/compare.py,sha256=g2UFhmR9Zt-7cz5gQFOB6lXuklXB3yTyUvjTypOzoSY,725
5
- jaxspec/analysis/results.py,sha256=_qwDSsThI7FOAR6nMaJltGWlKO5Sz2wc1EQ73Y0Ghho,26013
5
+ jaxspec/analysis/results.py,sha256=tIBWmLoX43EY2BXt50ec8A-DqQ98PMd3m-FqTRT4iRE,26073
6
6
  jaxspec/data/__init__.py,sha256=aantcYKC9kZFvaE-V2SIwSuLhIld17Kjrd9CIUu___Y,415
7
7
  jaxspec/data/instrument.py,sha256=RDiG_LkucvnF2XE_ghTFME6d_2YirgQUcEY0gEle6dk,4775
8
- jaxspec/data/obsconf.py,sha256=G0RwNshvbDQzw_ba8Y8NdI-cRsgEj-OlSNdeYCANqVM,10484
9
- jaxspec/data/observation.py,sha256=OHXfs7ApC8JAiG6h1teEmXu3iaNyhrqI5BCwl-qFvoM,7820
8
+ jaxspec/data/obsconf.py,sha256=bkYuD6mJgj8QmRaDVhcnXwUukVdo20xllzaI57prHag,10056
9
+ jaxspec/data/observation.py,sha256=7FHJm1jHEEFyrqxg3COsGmfdh5dg-5XnfKCp1yb5fNY,7411
10
10
  jaxspec/data/ogip.py,sha256=eMmBuROW4eMRxRHkPPyGHf933e0IcREqB8WMQFMS2lY,9810
11
11
  jaxspec/data/util.py,sha256=4_f6ByGjUEZXTwrB37dCyYaTB1pjF10Z0ho7-4GrQuc,9761
12
12
  jaxspec/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -16,15 +16,15 @@ jaxspec/experimental/intrument_models.py,sha256=vuRw7xypPI9YV-Hv8chVNP4ti24dCGjb
16
16
  jaxspec/experimental/nested_sampler.py,sha256=8jCAXQAe2mD5YSNSF0jia_rFWES_MzwRM3FrQQS_x7w,2807
17
17
  jaxspec/experimental/tabulated.py,sha256=H0llUiso2KGH4xUzTUSVPy-6I8D3wm707lU_Z1P5uq4,9429
18
18
  jaxspec/fit/__init__.py,sha256=OaS0-Hkb3Hd-AkE2o-KWfoWMX0NSCPY-_FP2znHf9l0,153
19
- jaxspec/fit/_bayesian_model.py,sha256=BeYukXr86Y1kEmSyiv-6QC4M2rM78Kx_MgGecu4ML98,15179
19
+ jaxspec/fit/_bayesian_model.py,sha256=7c2Twgz06QV1S9DdctdVk5YT1v7P-ln100bWXAvv7Go,15179
20
20
  jaxspec/fit/_build_model.py,sha256=pNZVuVfwOq3Pg23opH7xRv28DsSkQZpvy2Z-1hQSfNs,3219
21
- jaxspec/fit/_fitter.py,sha256=doBTJqTP5CN1OJhZHVlS3oMVOzPJyH4YqOnGevIIU68,8893
21
+ jaxspec/fit/_fitter.py,sha256=92gd1P7CNIqusGb64x_DpBcb0KcoGyfvSDiEnRbfqP0,9709
22
22
  jaxspec/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
23
23
  jaxspec/model/_graph_util.py,sha256=hPvHYmAxb7P3nyIecaZ7RqWOjwcZ1WvUByt_yNANiaY,4552
24
24
  jaxspec/model/abc.py,sha256=RGrqDrXVNjCy7GYBZL-l1PZ3Lpr37SsMIw7L9_B8WJ4,14773
25
25
  jaxspec/model/additive.py,sha256=rEONSy7b7lwfXIhuPqtI4y2Yhq55EqrlEtEckEe6TA0,20538
26
26
  jaxspec/model/background.py,sha256=VLSrU0YCW9GSHCtaEdcth-sp74aPyEVSizIMFkTpM7M,7759
27
- jaxspec/model/instrument.py,sha256=FSGgWCQjUfSVOzJ3YfrbQQ2abPlZ4P4ndtLL9Axcl-g,2217
27
+ jaxspec/model/instrument.py,sha256=1zLZgHmBZs8RLKTMT3Wu4bCx6JnxBUjhRIpYG2rLaZM,2947
28
28
  jaxspec/model/list.py,sha256=uC9rLEEeph10q6shat86WLACVuTSx73RGMl8Ij0jqQY,875
29
29
  jaxspec/model/multiplicative.py,sha256=odaOlF0K1KjzUDstPtWAk95ScHoZ7_XveOew6l3tbeU,8337
30
30
  jaxspec/scripts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -35,8 +35,8 @@ jaxspec/util/integrate.py,sha256=7GwBSagmDzsF3P53tPs-oakeq0zHEwmZZS2zQlXngbE,463
35
35
  jaxspec/util/misc.py,sha256=O3qorCL1Y2X1BS2jdd36C1eDHK9QDXTSOr9kj3uqcJo,654
36
36
  jaxspec/util/online_storage.py,sha256=vm56RfcbFKpkRVfr0bXO7J9aQxuBq-I_oEgA26YIhCo,2469
37
37
  jaxspec/util/typing.py,sha256=ZQM_l68qyYnIBZPz_1mKvwPMx64jvVBD8Uj6bx9sHv0,140
38
- jaxspec-0.3.0.dist-info/METADATA,sha256=92shp3kcwQIbKTSVSD7SU68InowsGVZXST0uJYvRwnQ,4199
39
- jaxspec-0.3.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
40
- jaxspec-0.3.0.dist-info/entry_points.txt,sha256=4ffU5AImfcEBxgWTqopQll2YffpFldOswXRh16pd0Dc,72
41
- jaxspec-0.3.0.dist-info/licenses/LICENSE.md,sha256=2q5XoWzddts5IqzIcgYYMOL21puU3MfO8gvT3Ype1eQ,1073
42
- jaxspec-0.3.0.dist-info/RECORD,,
38
+ jaxspec-0.3.2.dist-info/METADATA,sha256=10PjN7QwhbU8BoZc9f1Lga2n1u1_j4p8Lk2Syy6cJC8,4199
39
+ jaxspec-0.3.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
40
+ jaxspec-0.3.2.dist-info/entry_points.txt,sha256=4ffU5AImfcEBxgWTqopQll2YffpFldOswXRh16pd0Dc,72
41
+ jaxspec-0.3.2.dist-info/licenses/LICENSE.md,sha256=2q5XoWzddts5IqzIcgYYMOL21puU3MfO8gvT3Ype1eQ,1073
42
+ jaxspec-0.3.2.dist-info/RECORD,,