jaxspec 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxspec/data/util.py CHANGED
@@ -6,9 +6,9 @@ import haiku as hk
 from pathlib import Path
 from numpy.typing import ArrayLike
 from collections.abc import Mapping
-from typing import TypeVar
+from typing import TypeVar, Tuple
+from astropy.io import fits

-from .ogip import DataPHA, DataARF, DataRMF
 from . import Observation, Instrument, ObsConfiguration
 from ..model.abc import SpectralModel
 from ..fit import CountForwardModel
@@ -104,6 +104,7 @@ def fakeit(
     model: SpectralModel,
     parameters: Mapping[K, V],
     rng_key: int = 0,
+    sparsify_matrix: bool = False,
 ) -> ArrayLike | list[ArrayLike]:
     """
     This function is a convenience function that allows to simulate spectra from a given model and a set of parameters.
@@ -116,13 +117,16 @@ def fakeit(
         model: The model to use.
         parameters: The parameters of the model.
         rng_key: The random number generator seed.
+        sparsify_matrix: Whether to sparsify the matrix or not.
     """

     instruments = [instrument] if isinstance(instrument, ObsConfiguration) else instrument
     fakeits = []

     for i, instrument in enumerate(instruments):
-        transformed_model = hk.without_apply_rng(hk.transform(lambda par: CountForwardModel(model, instrument)(par)))
+        transformed_model = hk.without_apply_rng(
+            hk.transform(lambda par: CountForwardModel(model, instrument, sparse=sparsify_matrix)(par))
+        )

         def obs_model(p):
             return transformed_model.apply(None, p)
@@ -159,7 +163,8 @@ def fakeit_for_multiple_parameters(
     model: SpectralModel,
     parameters: Mapping[K, V],
     rng_key: int = 0,
-    apply_stat=True,
+    apply_stat: bool = True,
+    sparsify_matrix: bool = False,
 ):
     """
     This function is a convenience function that allows to simulate spectra multiple spectra from a given model and a
@@ -173,13 +178,16 @@ def fakeit_for_multiple_parameters(
         parameters: The parameters of the model.
         rng_key: The random number generator seed.
         apply_stat: Whether to apply Poisson statistic on the folded spectra or not.
+        sparsify_matrix: Whether to sparsify the matrix or not.
     """

     instruments = [instrument] if isinstance(instrument, ObsConfiguration) else instrument
     fakeits = []

     for i, obs in enumerate(instruments):
-        transformed_model = hk.without_apply_rng(hk.transform(lambda par: CountForwardModel(model, obs)(par)))
+        transformed_model = hk.without_apply_rng(
+            hk.transform(lambda par: CountForwardModel(model, obs, sparse=sparsify_matrix)(par))
+        )

         @jax.jit
         @jax.vmap
@@ -201,46 +209,32 @@ def fakeit_for_multiple_parameters(
     return fakeits[0] if len(fakeits) == 1 else fakeits


-def data_loader(pha_path: str, arf_path=None, rmf_path=None, bkg_path=None):
+def data_path_finder(pha_path: str) -> Tuple[str | None, str | None, str | None]:
     """
-    This function is a convenience function that allows to load PHA, ARF and RMF data
-    from a given PHA file, using either the ARF/RMF/BKG filenames in the header or the
-    specified filenames overwritten by the user.
-
+    This function tries its best to find the ARF, RMF and BKG files associated with a given PHA file.
     Parameters:
         pha_path: The PHA file path.
+
+    Returns:
         arf_path: The ARF file path.
         rmf_path: The RMF file path.
         bkg_path: The BKG file path.
     """

-    pha = DataPHA.from_file(pha_path)
-    directory = str(Path(pha_path).parent)
-
-    if arf_path is None:
-        if pha.ancrfile != "none" and pha.ancrfile != "":
-            arf_path = find_file_or_compressed_in_dir(pha.ancrfile, directory)
-
-    if rmf_path is None:
-        if pha.respfile != "none" and pha.respfile != "":
-            rmf_path = find_file_or_compressed_in_dir(pha.respfile, directory)
-
-    if bkg_path is None:
-        if pha.backfile.lower() != "none" and pha.backfile != "":
-            bkg_path = find_file_or_compressed_in_dir(pha.backfile, directory)
+    def find_path(file_name: str, directory: str) -> str | None:
+        if file_name.lower() != "none" and file_name != "":
+            return find_file_or_compressed_in_dir(file_name, directory)
+        else:
+            return None

-    arf = DataARF.from_file(arf_path) if arf_path is not None else None
-    rmf = DataRMF.from_file(rmf_path) if rmf_path is not None else None
-    bkg = DataPHA.from_file(bkg_path) if bkg_path is not None else None
+    header = fits.getheader(pha_path, "SPECTRUM")
+    directory = str(Path(pha_path).parent)

-    metadata = {
-        "observation_file": pha_path,
-        "background_file": bkg_path,
-        "response_matrix_file": rmf_path,
-        "ancillary_response_file": arf_path,
-    }
+    arf_path = find_path(header.get("ANCRFILE", "none"), directory)
+    rmf_path = find_path(header.get("RESPFILE", "none"), directory)
+    bkg_path = find_path(header.get("BACKFILE", "none"), directory)

-    return pha, arf, rmf, bkg, metadata
+    return arf_path, rmf_path, bkg_path


 def find_file_or_compressed_in_dir(path: str | Path, directory: str | Path) -> str:
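Usage note: below is a minimal, self-contained sketch of the new data_path_finder helper (jaxspec 0.0.5 and astropy are assumed to be installed; the PHA file is synthetic and purely illustrative). The new sparsify_matrix flag of fakeit is simply forwarded to CountForwardModel as sparse=..., as shown in the hunks above.

# Build a bare PHA file whose SPECTRUM header declares no associated files.
from astropy.io import fits
from jaxspec.data.util import data_path_finder

hdu = fits.BinTableHDU(name="SPECTRUM")
hdu.header["ANCRFILE"] = "none"
hdu.header["RESPFILE"] = "none"
hdu.header["BACKFILE"] = "none"
fits.HDUList([fits.PrimaryHDU(), hdu]).writeto("dummy.pha", overwrite=True)

# Each path is resolved from the header keywords, or None when a keyword is "none" or empty.
arf_path, rmf_path, bkg_path = data_path_finder("dummy.pha")
print(arf_path, rmf_path, bkg_path)  # -> None None None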
jaxspec/fit.py CHANGED
@@ -4,11 +4,12 @@ import numpyro
 import arviz as az
 import jax
 from typing import Callable, TypeVar
-from abc import ABC
+from abc import ABC, abstractmethod
 from jax import random
 from jax.tree_util import tree_map
+from jax.flatten_util import ravel_pytree
 from jax.experimental.sparse import BCSR
-from .analysis.results import ChainResult
+from .analysis.results import FitResult
 from .model.abc import SpectralModel
 from .data import ObsConfiguration
 from .model.background import BackgroundModel
@@ -17,13 +18,15 @@ from numpyro.distributions import Distribution, TransformedDistribution
 from numpyro.distributions import Poisson
 from jax.typing import ArrayLike
 from numpyro.infer.reparam import TransformReparam
+from numpyro.infer.util import initialize_model
+from jax.random import PRNGKey
+import jaxopt


 T = TypeVar("T")


-class HaikuDict(dict[str, dict[str, T]]):
-    ...
+class HaikuDict(dict[str, dict[str, T]]): ...


 def build_prior(prior: HaikuDict[Distribution | ArrayLike], expand_shape: tuple = ()):
@@ -32,7 +35,8 @@ def build_prior(prior: HaikuDict[Distribution | ArrayLike], expand_shape: tuple
     for i, (m, n, sample) in enumerate(hk.data_structures.traverse(prior)):
         match sample:
             case Distribution():
-                parameters[m][n] = numpyro.sample(f"{m}_{n}", sample.expand(expand_shape))
+                parameters[m][n] = jnp.ones(expand_shape) * numpyro.sample(f"{m}_{n}", sample)
+                # parameters[m][n] = numpyro.sample(f"{m}_{n}", sample.expand(expand_shape)) build a free parameter for each obs
             case float() | ArrayLike():
                 parameters[m][n] = jnp.ones(expand_shape) * sample
             case _:
@@ -54,7 +58,7 @@ def build_numpyro_model(

         if (getattr(obs, "folded_background", None) is not None) and (background_model is not None):
             bkg_countrate = background_model.numpyro_model(
-                obs.out_energies, obs.folded_background.data, name=name + "bkg", observed=observed
+                obs.out_energies, obs.folded_background.data, name="bkg_" + name, observed=observed
             )
         elif (getattr(obs, "folded_background", None) is None) and (background_model is not None):
             raise ValueError("Trying to fit a background model but no background is linked to this observation")
@@ -66,9 +70,9 @@ def build_numpyro_model(
         countrate = obs_model(prior_params)

         # This is the case where we fit a model to a TOTAL spectrum as defined in OGIP standard
-        with numpyro.plate(name + "obs_plate", len(obs.folded_counts)):
+        with numpyro.plate("obs_plate_" + name, len(obs.folded_counts)):
             numpyro.sample(
-                name + "obs",
+                "obs_" + name,
                 Poisson(countrate + bkg_countrate / obs.folded_backratio.data),
                 obs=obs.folded_counts.data if observed else None,
             )
@@ -76,6 +80,25 @@ def build_numpyro_model(
     return numpro_model


+def filter_inference_data(inference_data, observation_container, background_model=None) -> az.InferenceData:
+    predictive_parameters = []
+
+    for key, value in observation_container.items():
+        if background_model is not None:
+            predictive_parameters.append(f"obs_{key}")
+            predictive_parameters.append(f"bkg_{key}")
+        else:
+            predictive_parameters.append(f"obs_{key}")
+
+    inference_data.posterior_predictive = inference_data.posterior_predictive[predictive_parameters]
+
+    parameters = [x for x in inference_data.posterior.keys() if not x.endswith("_base")]
+    inference_data.posterior = inference_data.posterior[parameters]
+    inference_data.prior = inference_data.prior[parameters]
+
+    return inference_data
+
+
 class CountForwardModel(hk.Module):
     """
     A haiku module which allows to build the function that simulates the measured counts
@@ -97,27 +120,87 @@ class CountForwardModel(hk.Module):
         Compute the count functions for a given observation.
         """

-        expected_counts = self.transfer_matrix @ self.model(parameters, *self.energies)
+        expected_counts = self.transfer_matrix @ self.model.photon_flux(parameters, *self.energies)

         return jnp.clip(expected_counts, a_min=1e-6)


-class BayesianModelAbstract(ABC):
+class ModelFitter(ABC):
     """
     Abstract class to fit a model to a given set of observation.
     """

-    model: SpectralModel
-    """The model to fit to the data."""
-    numpyro_model: Callable
-    """The numpyro model defining the likelihood."""
-    background_model: BackgroundModel
-    """The background model."""
-    pars: dict
+    def __init__(
+        self,
+        model: SpectralModel,
+        observations: ObsConfiguration | list[ObsConfiguration] | dict[str, ObsConfiguration],
+        background_model: BackgroundModel = None,
+        sparsify_matrix: bool = False,
+    ):
+        """
+        Initialize the fitter.

-    def __init__(self, model: SpectralModel):
+        Parameters:
+            model: the spectral model to fit.
+            observations: the observations to fit the model to.
+            background_model: the background model to fit.
+            sparsify_matrix: whether to sparsify the transfer matrix.
+        """
         self.model = model
+        self._observations = observations
+        self.background_model = background_model
         self.pars = tree_map(lambda x: jnp.float64(x), self.model.params)
+        self.sparse = sparsify_matrix
+
+    @property
+    def _observation_container(self) -> dict[str, ObsConfiguration]:
+        """
+        The observations used in the fit as a dictionary of observations.
+        """
+
+        if isinstance(self._observations, dict):
+            return self._observations
+
+        elif isinstance(self._observations, list):
+            return {f"data_{i}": obs for i, obs in enumerate(self._observations)}
+
+        elif isinstance(self._observations, ObsConfiguration):
+            return {"data": self._observations}
+
+        else:
+            raise ValueError(f"Invalid type for observations : {type(self._observations)}")
+
+    def numpyro_model(self, prior_distributions: HaikuDict[Distribution]) -> Callable:
+        """
+        Build the numpyro model using the observed data, the prior distributions and the spectral model.
+
+        Parameters:
+            prior_distributions: a nested dictionary containing the prior distributions for the model parameters.
+
+        Returns:
+            A model function that can be used with numpyro.
+        """
+
+        def model(observed=True):
+            prior_params = build_prior(prior_distributions, expand_shape=(len(self._observation_container),))
+
+            for i, (key, observation) in enumerate(self._observation_container.items()):
+                params = tree_map(lambda x: x[i], prior_params)
+
+                obs_model = build_numpyro_model(observation, self.model, self.background_model, name=key, sparse=self.sparse)
+                obs_model(params, observed=observed)
+
+        return model
+
+    @abstractmethod
+    def fit(self, prior_distributions: HaikuDict[Distribution], **kwargs) -> FitResult: ...
+
+
+class BayesianFitter(ModelFitter):
+    """
+    A class to fit a model to a given set of observation using a Bayesian approach. This class uses the NUTS sampler
+    from numpyro to perform the inference on the model parameters.
+    """

     def fit(
         self,
@@ -128,11 +211,11 @@ class BayesianModelAbstract(ABC):
         num_samples: int = 1000,
         max_tree_depth: int = 10,
         target_accept_prob: float = 0.8,
-        dense_mass=False,
+        dense_mass: bool = False,
         mcmc_kwargs: dict = {},
-    ) -> ChainResult:
+    ) -> FitResult:
         """
-        Fit the model to the data using NUTS sampler from numpyro. This is the default sampler in jaxspec.
+        Fit the model to the data using NUTS sampler from numpyro.

         Parameters:
             prior_distributions: a nested dictionary containing the prior distributions for the model parameters.
@@ -146,7 +229,7 @@ class BayesianModelAbstract(ABC):
             mcmc_kwargs: additional arguments to pass to the MCMC sampler. See [`MCMC`][numpyro.infer.mcmc.MCMC] for more details.

         Returns:
-            A [`ChainResult`][jaxspec.analysis.results.ChainResult] instance containing the results of the fit.
+            A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
         """

         transform_dict = {}
@@ -174,80 +257,91 @@ class BayesianModelAbstract(ABC):
         prior = Predictive(bayesian_model, num_samples=num_samples)(keys[2], observed=False)
         inference_data = az.from_numpyro(mcmc, prior=prior, posterior_predictive=posterior_predictive)

-        predictive_parameters = ["obs", "bkg"] if self.background_model is not None else ["obs"]
-        inference_data.posterior_predictive = inference_data.posterior_predictive[predictive_parameters]
-
-        parameters = [x for x in inference_data.posterior.keys() if not x.endswith("_base")]
-        inference_data.posterior = inference_data.posterior[parameters]
-        inference_data.prior = inference_data.prior[parameters]
+        inference_data = filter_inference_data(inference_data, self._observation_container, self.background_model)

-        return ChainResult(
+        return FitResult(
             self.model,
-            self.observation,
+            self._observation_container,
             inference_data,
-            mcmc.get_samples(),
             self.model.params,
             background_model=self.background_model,
         )


-class BayesianModel(BayesianModelAbstract):
+class MinimizationFitter(ModelFitter):
     """
-    Class to fit a model to a given observation using a Bayesian approach.
+    A class to fit a model to a given set of observation using a minimization algorithm. This class uses the L-BFGS
+    algorithm from jaxopt to perform the minimization on the model parameters. The uncertainties are computed using the
+    Hessian of the log-likelihood, assuming that it is a multivariate Gaussian in the unbounded space defined by
+    numpyro.
     """

-    def __init__(self, model, observation, background_model: BackgroundModel = None, sparsify_matrix: bool = False):
-        super().__init__(model)
-        self.observation = observation
-        self.pars = tree_map(lambda x: jnp.float64(x), self.model.params)
-        self.sparse = sparsify_matrix
-        self.background_model = background_model
-
-    def numpyro_model(self, prior_distributions: HaikuDict[Distribution]) -> Callable:
+    def fit(
+        self,
+        prior_distributions: HaikuDict[Distribution],
+        rng_key: int = 0,
+        num_iter_max: int = 10_000,
+        num_samples: int = 1_000,
+    ) -> FitResult:
         """
-        Build the numpyro model for the Bayesian fit. It returns a callable which can be used
-        to fit the model using numpyro's various samplers.
+        Fit the model to the data using L-BFGS algorithm.

         Parameters:
             prior_distributions: a nested dictionary containing the prior distributions for the model parameters.
-        """
-
-        def model(observed=True):
-            prior_params = build_prior(prior_distributions)
-            obs_model = build_numpyro_model(self.observation, self.model, self.background_model, sparse=self.sparse)
-            obs_model(prior_params, observed=observed)
-
-        return model
+            rng_key: the random key used to initialize the sampler.
+            num_iter_max: the maximum number of iteration in the minimization algorithm.
+            num_samples: the number of sample to draw from the best-fit covariance.

+        Returns:
+            A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
+        """

-"""
-class MultipleObservationMCMC(BayesianModelAbstract):
+        bayesian_model = self.numpyro_model(prior_distributions)

-    def __init__(self, model, observations, background_model: BackgroundModel = None):
+        param_info, potential_fn, postprocess_fn, *_ = initialize_model(
+            PRNGKey(0),
+            bayesian_model,
+            model_args=tuple(),
+            dynamic_args=True,  # <- this is important!
+        )

-        super().__init__(model)
-        self.observations = observations
-        self.pars = tree_map(lambda x: jnp.float64(x), self.model.params)
-        self.background_model = background_model
+        # get negative log-density from the potential function
+        @jax.jit
+        def nll_fn(position):
+            func = potential_fn()
+            return func(position)

-    def numpyro_model(self, prior_distributions: HaikuDict[Distribution]) -> Callable:
+        solver = jaxopt.LBFGS(fun=nll_fn, maxiter=10_000)
+        params, state = solver.run(param_info.z)
+        keys = random.split(random.PRNGKey(rng_key), 3)

-        def model(observed=True):
+        value_flat, unflatten_fun = ravel_pytree(params)
+        covariance = jnp.linalg.inv(jax.hessian(lambda p: nll_fn(unflatten_fun(p)))(value_flat))

-            prior_params = build_prior(prior_distributions, expand_shape=(len(self.observations),))
+        samples_flat = jax.random.multivariate_normal(keys[0], value_flat, covariance, shape=(num_samples,))
+        samples = jax.vmap(unflatten_fun)(samples_flat.block_until_ready())
+        posterior_samples = postprocess_fn()(samples)

-            for i, (key, observation) in enumerate(self.observations.items()):
+        posterior_predictive = Predictive(bayesian_model, posterior_samples)(keys[1], observed=False)
+        prior = Predictive(bayesian_model, num_samples=num_samples)(keys[2], observed=False)
+        log_likelihood = numpyro.infer.log_likelihood(bayesian_model, posterior_samples)

-                params = tree_map(lambda x: x[i], prior_params)
+        def sanitize_chain(chain):
+            return tree_map(lambda x: x[None, ...], chain)

-                obs_model = build_numpyro_model(
-                    observation,
-                    self.model,
-                    self.background_model,
-                    name=key + '_'
-                )
+        inference_data = az.from_dict(
+            sanitize_chain(posterior_samples),
+            prior=sanitize_chain(prior),
+            posterior_predictive=sanitize_chain(posterior_predictive),
+            log_likelihood=sanitize_chain(log_likelihood),
+        )

-                obs_model(params, observed=observed)
+        inference_data = filter_inference_data(inference_data, self._observation_container, self.background_model)

-        return model
-    """
+        return FitResult(
+            self.model,
+            self._observation_container,
+            inference_data,
+            self.model.params,
+            background_model=self.background_model,
+        )
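Usage note: a hedged sketch of the refactored fitting API above. The fitter names, constructor signature and fit() arguments come from this diff; model, obs and the prior dictionary below are placeholders standing in for a SpectralModel, an ObsConfiguration and component-wise numpyro priors built elsewhere.

import numpyro.distributions as dist
from jaxspec.fit import BayesianFitter, MinimizationFitter

# Hypothetical prior: nested dict mapping component name -> parameter -> distribution.
prior = {
    "powerlaw_1": {
        "alpha": dist.Uniform(0.0, 5.0),
        "norm": dist.LogUniform(1e-6, 1e-2),
    }
}

# ModelFitter accepts a single ObsConfiguration, a list, or a {name: ObsConfiguration} dict.
nuts_result = BayesianFitter(model, {"pn": obs}, sparsify_matrix=True).fit(prior, num_samples=1000)

# MinimizationFitter runs jaxopt's L-BFGS and draws samples from the Hessian-based covariance.
lbfgs_result = MinimizationFitter(model, {"pn": obs}).fit(prior, num_samples=1000)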
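For context, the Laplace-approximation pattern used by MinimizationFitter.fit can be reproduced standalone: minimize numpyro's potential function with jaxopt's L-BFGS, invert the Hessian at the optimum, and sample a Gaussian in the unconstrained space. The toy model below is illustrative and not part of jaxspec.

import jax
import jax.numpy as jnp
import jaxopt
import numpyro
import numpyro.distributions as dist
from jax.flatten_util import ravel_pytree
from numpyro.infer.util import initialize_model


def toy_model(observed=True):
    # toy likelihood: a Gaussian mean with a wide Normal prior
    mu = numpyro.sample("mu", dist.Normal(0.0, 10.0))
    numpyro.sample("obs", dist.Normal(mu, 1.0), obs=jnp.array([1.2, 0.8, 1.1]) if observed else None)


param_info, potential_fn, postprocess_fn, *_ = initialize_model(
    jax.random.PRNGKey(0), toy_model, model_args=(), dynamic_args=True
)


def nll_fn(position):
    # negative log-density of the model in the unconstrained space
    return potential_fn()(position)


params, _ = jaxopt.LBFGS(fun=nll_fn, maxiter=1_000).run(param_info.z)
value_flat, unflatten = ravel_pytree(params)
covariance = jnp.linalg.inv(jax.hessian(lambda p: nll_fn(unflatten(p)))(value_flat))

samples_flat = jax.random.multivariate_normal(jax.random.PRNGKey(1), value_flat, covariance, shape=(1_000,))
posterior = postprocess_fn()(jax.vmap(unflatten)(samples_flat))  # back to the constrained space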