jaxspec 0.0.7__py3-none-any.whl → 0.0.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxspec/analysis/results.py +109 -62
- jaxspec/data/util.py +100 -79
- jaxspec/fit.py +78 -87
- jaxspec/model/additive.py +167 -42
- jaxspec/model/multiplicative.py +55 -28
- jaxspec/util/online_storage.py +13 -0
- {jaxspec-0.0.7.dist-info → jaxspec-0.0.8.dist-info}/METADATA +7 -6
- {jaxspec-0.0.7.dist-info → jaxspec-0.0.8.dist-info}/RECORD +10 -14
- {jaxspec-0.0.7.dist-info → jaxspec-0.0.8.dist-info}/WHEEL +1 -1
- jaxspec/data/example_data/MOS1.pha +0 -46
- jaxspec/data/example_data/MOS2.pha +0 -42
- jaxspec/data/example_data/PN.pha +1 -293
- jaxspec/data/example_data/fakeit.pha +1 -335
- {jaxspec-0.0.7.dist-info → jaxspec-0.0.8.dist-info}/LICENSE.md +0 -0
jaxspec/fit.py
CHANGED
|
@@ -19,6 +19,7 @@ from numpyro.contrib.nested_sampling import NestedSampler
|
|
|
19
19
|
from numpyro.distributions import Distribution, Poisson, TransformedDistribution
|
|
20
20
|
from numpyro.infer import MCMC, NUTS, Predictive
|
|
21
21
|
from numpyro.infer.initialization import init_to_value
|
|
22
|
+
from numpyro.infer.inspect import get_model_relations
|
|
22
23
|
from numpyro.infer.reparam import TransformReparam
|
|
23
24
|
from numpyro.infer.util import constrain_fn
|
|
24
25
|
from scipy.stats import Covariance, multivariate_normal
|
|
@@ -27,7 +28,7 @@ from .analysis.results import FitResult
|
|
|
27
28
|
from .data import ObsConfiguration
|
|
28
29
|
from .model.abc import SpectralModel
|
|
29
30
|
from .model.background import BackgroundModel
|
|
30
|
-
from .util import catchtime
|
|
31
|
+
from .util import catchtime
|
|
31
32
|
from .util.typing import PriorDictModel, PriorDictType
|
|
32
33
|
|
|
33
34
|
|
|
@@ -39,17 +40,16 @@ def build_prior(prior: PriorDictType, expand_shape: tuple = (), prefix=""):
|
|
|
39
40
|
parameters = dict(hk.data_structures.to_haiku_dict(prior))
|
|
40
41
|
|
|
41
42
|
for i, (m, n, sample) in enumerate(hk.data_structures.traverse(prior)):
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
)
|
|
43
|
+
if isinstance(sample, Distribution):
|
|
44
|
+
parameters[m][n] = jnp.ones(expand_shape) * numpyro.sample(f"{prefix}{m}_{n}", sample)
|
|
45
|
+
|
|
46
|
+
elif isinstance(sample, ArrayLike):
|
|
47
|
+
parameters[m][n] = jnp.ones(expand_shape) * sample
|
|
48
|
+
|
|
49
|
+
else:
|
|
50
|
+
raise ValueError(
|
|
51
|
+
f"Invalid prior type {type(sample)} for parameter {prefix}{m}_{n} : {sample}"
|
|
52
|
+
)
|
|
53
53
|
|
|
54
54
|
return parameters
|
|
55
55
|
|
|
@@ -90,7 +90,7 @@ def build_numpyro_model_for_single_obs(
|
|
|
90
90
|
with numpyro.plate("obs_plate_" + name, len(obs.folded_counts)):
|
|
91
91
|
numpyro.sample(
|
|
92
92
|
"obs_" + name,
|
|
93
|
-
Poisson(countrate + bkg_countrate
|
|
93
|
+
Poisson(countrate + bkg_countrate / obs.folded_backratio.data),
|
|
94
94
|
obs=obs.folded_counts.data if observed else None,
|
|
95
95
|
)
|
|
96
96
|
|
|
@@ -148,14 +148,15 @@ class CountForwardModel(hk.Module):
|
|
|
148
148
|
return jnp.clip(expected_counts, a_min=1e-6)
|
|
149
149
|
|
|
150
150
|
|
|
151
|
-
class
|
|
151
|
+
class BayesianModel:
|
|
152
152
|
"""
|
|
153
|
-
|
|
153
|
+
Class to fit a model to a given set of observation.
|
|
154
154
|
"""
|
|
155
155
|
|
|
156
156
|
def __init__(
|
|
157
157
|
self,
|
|
158
158
|
model: SpectralModel,
|
|
159
|
+
prior_distributions: PriorDictType | Callable,
|
|
159
160
|
observations: ObsConfiguration | list[ObsConfiguration] | dict[str, ObsConfiguration],
|
|
160
161
|
background_model: BackgroundModel = None,
|
|
161
162
|
sparsify_matrix: bool = False,
|
|
@@ -163,9 +164,10 @@ class ModelFitter(ABC):
|
|
|
163
164
|
"""
|
|
164
165
|
Initialize the fitter.
|
|
165
166
|
|
|
166
|
-
Parameters
|
|
167
|
-
----------
|
|
167
|
+
Parameters:
|
|
168
168
|
model: the spectral model to fit.
|
|
169
|
+
prior_distributions: a nested dictionary containing the prior distributions for the model parameters, or a
|
|
170
|
+
callable function that returns parameter samples.
|
|
169
171
|
observations: the observations to fit the model to.
|
|
170
172
|
background_model: the background model to fit.
|
|
171
173
|
sparsify_matrix: whether to sparsify the transfer matrix.
|
|
@@ -176,8 +178,20 @@ class ModelFitter(ABC):
|
|
|
176
178
|
self.pars = tree_map(lambda x: jnp.float64(x), self.model.params)
|
|
177
179
|
self.sparse = sparsify_matrix
|
|
178
180
|
|
|
181
|
+
if not callable(prior_distributions):
|
|
182
|
+
# Validate the entry with pydantic
|
|
183
|
+
prior = PriorDictModel(nested_dict=prior_distributions).nested_dict
|
|
184
|
+
|
|
185
|
+
def prior_distributions_func():
|
|
186
|
+
return build_prior(prior, expand_shape=(len(self.observation_container),))
|
|
187
|
+
|
|
188
|
+
else:
|
|
189
|
+
prior_distributions_func = prior_distributions
|
|
190
|
+
|
|
191
|
+
self.prior_distributions_func = prior_distributions_func
|
|
192
|
+
|
|
179
193
|
@property
|
|
180
|
-
def
|
|
194
|
+
def observation_container(self) -> dict[str, ObsConfiguration]:
|
|
181
195
|
"""
|
|
182
196
|
The observations used in the fit as a dictionary of observations.
|
|
183
197
|
"""
|
|
@@ -194,36 +208,21 @@ class ModelFitter(ABC):
|
|
|
194
208
|
else:
|
|
195
209
|
raise ValueError(f"Invalid type for observations : {type(self._observations)}")
|
|
196
210
|
|
|
197
|
-
|
|
211
|
+
@property
|
|
212
|
+
def numpyro_model(self) -> Callable:
|
|
198
213
|
"""
|
|
199
214
|
Build the numpyro model using the observed data, the prior distributions and the spectral model.
|
|
200
215
|
|
|
201
|
-
Parameters
|
|
202
|
-
----------
|
|
203
|
-
prior_distributions: a nested dictionary containing the prior distributions for the model parameters.
|
|
204
|
-
|
|
205
216
|
Returns:
|
|
206
217
|
-------
|
|
207
218
|
A model function that can be used with numpyro.
|
|
208
219
|
"""
|
|
209
220
|
|
|
210
|
-
if not callable(prior_distributions):
|
|
211
|
-
# Validate the entry with pydantic
|
|
212
|
-
prior_distributions = PriorDictModel(nested_dict=prior_distributions).nested_dict
|
|
213
|
-
|
|
214
|
-
def prior_distributions_func():
|
|
215
|
-
return build_prior(
|
|
216
|
-
prior_distributions, expand_shape=(len(self._observation_container),)
|
|
217
|
-
)
|
|
218
|
-
|
|
219
|
-
else:
|
|
220
|
-
prior_distributions_func = prior_distributions
|
|
221
|
-
|
|
222
221
|
def model(observed=True):
|
|
223
|
-
prior_params = prior_distributions_func()
|
|
222
|
+
prior_params = self.prior_distributions_func()
|
|
224
223
|
|
|
225
224
|
# Iterate over all the observations in our container and build a single numpyro model for each observation
|
|
226
|
-
for i, (key, observation) in enumerate(self.
|
|
225
|
+
for i, (key, observation) in enumerate(self.observation_container.items()):
|
|
227
226
|
# We expect that prior_params contains an array of parameters for each observation
|
|
228
227
|
# They can be identical or different for each observation
|
|
229
228
|
params = tree_map(lambda x: x[i], prior_params)
|
|
@@ -236,22 +235,29 @@ class ModelFitter(ABC):
|
|
|
236
235
|
|
|
237
236
|
return model
|
|
238
237
|
|
|
239
|
-
|
|
238
|
+
@property
|
|
239
|
+
def transformed_numpyro_model(self) -> Callable:
|
|
240
240
|
transform_dict = {}
|
|
241
241
|
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
242
|
+
relations = get_model_relations(self.numpyro_model)
|
|
243
|
+
distributions = {
|
|
244
|
+
parameter: getattr(numpyro.distributions, value, None)
|
|
245
|
+
for parameter, value in relations["sample_dist"].items()
|
|
246
|
+
}
|
|
247
|
+
|
|
248
|
+
for parameter, distribution in distributions.items():
|
|
249
|
+
if isinstance(distribution, TransformedDistribution):
|
|
250
|
+
transform_dict[parameter] = TransformReparam()
|
|
251
|
+
|
|
252
|
+
return numpyro.handlers.reparam(self.numpyro_model, config=transform_dict)
|
|
245
253
|
|
|
246
|
-
return numpyro.handlers.reparam(
|
|
247
|
-
self.numpyro_model(prior_distributions), config=transform_dict
|
|
248
|
-
)
|
|
249
254
|
|
|
255
|
+
class BayesianModelFitter(BayesianModel, ABC):
|
|
250
256
|
@abstractmethod
|
|
251
|
-
def fit(self,
|
|
257
|
+
def fit(self, **kwargs) -> FitResult: ...
|
|
252
258
|
|
|
253
259
|
|
|
254
|
-
class
|
|
260
|
+
class NUTSFitter(BayesianModelFitter):
|
|
255
261
|
"""
|
|
256
262
|
A class to fit a model to a given set of observation using a Bayesian approach. This class uses the NUTS sampler
|
|
257
263
|
from numpyro to perform the inference on the model parameters.
|
|
@@ -259,7 +265,6 @@ class BayesianFitter(ModelFitter):
|
|
|
259
265
|
|
|
260
266
|
def fit(
|
|
261
267
|
self,
|
|
262
|
-
prior_distributions: PriorDictType,
|
|
263
268
|
rng_key: int = 0,
|
|
264
269
|
num_chains: int = len(jax.devices()),
|
|
265
270
|
num_warmup: int = 1000,
|
|
@@ -273,9 +278,7 @@ class BayesianFitter(ModelFitter):
|
|
|
273
278
|
"""
|
|
274
279
|
Fit the model to the data using NUTS sampler from numpyro.
|
|
275
280
|
|
|
276
|
-
Parameters
|
|
277
|
-
----------
|
|
278
|
-
prior_distributions: a nested dictionary containing the prior distributions for the model parameters.
|
|
281
|
+
Parameters:
|
|
279
282
|
rng_key: the random key used to initialize the sampler.
|
|
280
283
|
num_chains: the number of chains to run.
|
|
281
284
|
num_warmup: the number of warmup steps.
|
|
@@ -286,11 +289,10 @@ class BayesianFitter(ModelFitter):
|
|
|
286
289
|
mcmc_kwargs: additional arguments to pass to the MCMC sampler. See [`MCMC`][numpyro.infer.mcmc.MCMC] for more details.
|
|
287
290
|
|
|
288
291
|
Returns:
|
|
289
|
-
-------
|
|
290
292
|
A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
|
|
291
293
|
"""
|
|
292
294
|
|
|
293
|
-
bayesian_model = self.transformed_numpyro_model
|
|
295
|
+
bayesian_model = self.transformed_numpyro_model
|
|
294
296
|
# bayesian_model = self.numpyro_model(prior_distributions)
|
|
295
297
|
|
|
296
298
|
chain_kwargs = {
|
|
@@ -311,28 +313,30 @@ class BayesianFitter(ModelFitter):
|
|
|
311
313
|
keys = random.split(random.PRNGKey(rng_key), 3)
|
|
312
314
|
|
|
313
315
|
mcmc.run(keys[0])
|
|
316
|
+
|
|
314
317
|
posterior_predictive = Predictive(bayesian_model, mcmc.get_samples())(
|
|
315
318
|
keys[1], observed=False
|
|
316
319
|
)
|
|
320
|
+
|
|
317
321
|
prior = Predictive(bayesian_model, num_samples=num_samples)(keys[2], observed=False)
|
|
322
|
+
|
|
318
323
|
inference_data = az.from_numpyro(
|
|
319
324
|
mcmc, prior=prior, posterior_predictive=posterior_predictive
|
|
320
325
|
)
|
|
321
326
|
|
|
322
327
|
inference_data = filter_inference_data(
|
|
323
|
-
inference_data, self.
|
|
328
|
+
inference_data, self.observation_container, self.background_model
|
|
324
329
|
)
|
|
325
330
|
|
|
326
331
|
return FitResult(
|
|
327
|
-
self
|
|
328
|
-
self._observation_container,
|
|
332
|
+
self,
|
|
329
333
|
inference_data,
|
|
330
334
|
self.model.params,
|
|
331
335
|
background_model=self.background_model,
|
|
332
336
|
)
|
|
333
337
|
|
|
334
338
|
|
|
335
|
-
class MinimizationFitter(
|
|
339
|
+
class MinimizationFitter(BayesianModelFitter):
|
|
336
340
|
"""
|
|
337
341
|
A class to fit a model to a given set of observation using a minimization algorithm. This class uses the L-BFGS
|
|
338
342
|
algorithm from jaxopt to perform the minimization on the model parameters. The uncertainties are computed using the
|
|
@@ -342,7 +346,6 @@ class MinimizationFitter(ModelFitter):
|
|
|
342
346
|
|
|
343
347
|
def fit(
|
|
344
348
|
self,
|
|
345
|
-
prior_distributions: PriorDictType,
|
|
346
349
|
rng_key: int = 0,
|
|
347
350
|
num_iter_max: int = 100_000,
|
|
348
351
|
num_samples: int = 1_000,
|
|
@@ -353,27 +356,24 @@ class MinimizationFitter(ModelFitter):
|
|
|
353
356
|
"""
|
|
354
357
|
Fit the model to the data using L-BFGS algorithm.
|
|
355
358
|
|
|
356
|
-
Parameters
|
|
357
|
-
----------
|
|
358
|
-
prior_distributions: a nested dictionary containing the prior distributions for the model parameters.
|
|
359
|
+
Parameters:
|
|
359
360
|
rng_key: the random key used to initialize the sampler.
|
|
360
361
|
num_iter_max: the maximum number of iteration in the minimization algorithm.
|
|
361
362
|
num_samples: the number of sample to draw from the best-fit covariance.
|
|
362
363
|
|
|
363
364
|
Returns:
|
|
364
|
-
-------
|
|
365
365
|
A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
|
|
366
366
|
"""
|
|
367
367
|
|
|
368
|
-
bayesian_model = self.numpyro_model
|
|
368
|
+
bayesian_model = self.numpyro_model
|
|
369
369
|
keys = jax.random.split(PRNGKey(rng_key), 4)
|
|
370
370
|
|
|
371
371
|
if init_params is not None:
|
|
372
372
|
# We initialize the parameters by randomly sampling from the prior
|
|
373
373
|
local_keys = jax.random.split(keys[0], 2)
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
374
|
+
|
|
375
|
+
with numpyro.handlers.seed(rng_seed=local_keys[0]):
|
|
376
|
+
starting_value = self.prior_distributions_func()
|
|
377
377
|
|
|
378
378
|
# We update the starting value with the provided init_params
|
|
379
379
|
for m, n, val in hk.data_structures.traverse(init_params):
|
|
@@ -482,19 +482,18 @@ class MinimizationFitter(ModelFitter):
|
|
|
482
482
|
)
|
|
483
483
|
|
|
484
484
|
inference_data = filter_inference_data(
|
|
485
|
-
inference_data, self.
|
|
485
|
+
inference_data, self.observation_container, self.background_model
|
|
486
486
|
)
|
|
487
487
|
|
|
488
488
|
return FitResult(
|
|
489
|
-
self
|
|
490
|
-
self._observation_container,
|
|
489
|
+
self,
|
|
491
490
|
inference_data,
|
|
492
491
|
self.model.params,
|
|
493
492
|
background_model=self.background_model,
|
|
494
493
|
)
|
|
495
494
|
|
|
496
495
|
|
|
497
|
-
class NestedSamplingFitter(
|
|
496
|
+
class NestedSamplingFitter(BayesianModelFitter):
|
|
498
497
|
r"""
|
|
499
498
|
A class to fit a model to a given set of observation using the Nested Sampling algorithm. This class uses the
|
|
500
499
|
[`DefaultNestedSampler`][jaxns.DefaultNestedSampler] from [`jaxns`](https://jaxns.readthedocs.io/en/latest/) which
|
|
@@ -504,18 +503,16 @@ class NestedSamplingFitter(ModelFitter):
|
|
|
504
503
|
|
|
505
504
|
def fit(
|
|
506
505
|
self,
|
|
507
|
-
prior_distributions: PriorDictType,
|
|
508
506
|
rng_key: int = 0,
|
|
509
|
-
num_parallel_workers: int = len(jax.devices()),
|
|
510
507
|
num_samples: int = 1000,
|
|
511
508
|
plot_diagnostics=False,
|
|
509
|
+
termination_kwargs: dict | None = None,
|
|
512
510
|
verbose=True,
|
|
513
511
|
) -> FitResult:
|
|
514
512
|
"""
|
|
515
513
|
Fit the model to the data using the Phantom-Powered nested sampling algorithm.
|
|
516
514
|
|
|
517
515
|
Parameters:
|
|
518
|
-
prior_distributions: a nested dictionary containing the prior distributions for the model parameters.
|
|
519
516
|
rng_key: the random key used to initialize the sampler.
|
|
520
517
|
num_samples: the number of samples to draw.
|
|
521
518
|
|
|
@@ -523,39 +520,34 @@ class NestedSamplingFitter(ModelFitter):
|
|
|
523
520
|
A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
|
|
524
521
|
"""
|
|
525
522
|
|
|
526
|
-
bayesian_model = self.transformed_numpyro_model
|
|
523
|
+
bayesian_model = self.transformed_numpyro_model
|
|
527
524
|
keys = random.split(random.PRNGKey(rng_key), 4)
|
|
528
525
|
|
|
529
526
|
ns = NestedSampler(
|
|
530
527
|
bayesian_model,
|
|
531
528
|
constructor_kwargs=dict(
|
|
532
|
-
num_parallel_workers=
|
|
529
|
+
num_parallel_workers=1,
|
|
533
530
|
verbose=verbose,
|
|
534
531
|
difficult_model=True,
|
|
535
|
-
|
|
536
|
-
# num_live_points=10_000,
|
|
537
|
-
# init_efficiency_threshold=0.5,
|
|
532
|
+
max_samples=1e6,
|
|
538
533
|
parameter_estimation=True,
|
|
534
|
+
num_live_points=1_000,
|
|
539
535
|
),
|
|
540
|
-
termination_kwargs=dict(
|
|
536
|
+
termination_kwargs=termination_kwargs if termination_kwargs else dict(),
|
|
541
537
|
)
|
|
542
538
|
|
|
543
539
|
ns.run(keys[0])
|
|
544
540
|
|
|
545
|
-
self.ns = ns
|
|
546
|
-
|
|
547
541
|
if plot_diagnostics:
|
|
548
542
|
ns.diagnostics()
|
|
549
543
|
|
|
550
|
-
posterior_samples = ns.get_samples(keys[1], num_samples=num_samples
|
|
544
|
+
posterior_samples = ns.get_samples(keys[1], num_samples=num_samples)
|
|
551
545
|
log_likelihood = numpyro.infer.log_likelihood(bayesian_model, posterior_samples)
|
|
552
546
|
posterior_predictive = Predictive(bayesian_model, posterior_samples)(
|
|
553
547
|
keys[2], observed=False
|
|
554
548
|
)
|
|
555
549
|
|
|
556
|
-
prior = Predictive(bayesian_model, num_samples=num_samples
|
|
557
|
-
keys[3], observed=False
|
|
558
|
-
)
|
|
550
|
+
prior = Predictive(bayesian_model, num_samples=num_samples)(keys[3], observed=False)
|
|
559
551
|
|
|
560
552
|
seeded_model = numpyro.handlers.substitute(
|
|
561
553
|
numpyro.handlers.seed(bayesian_model, jax.random.PRNGKey(0)),
|
|
@@ -583,12 +575,11 @@ class NestedSamplingFitter(ModelFitter):
|
|
|
583
575
|
)
|
|
584
576
|
|
|
585
577
|
inference_data = filter_inference_data(
|
|
586
|
-
inference_data, self.
|
|
578
|
+
inference_data, self.observation_container, self.background_model
|
|
587
579
|
)
|
|
588
580
|
|
|
589
581
|
return FitResult(
|
|
590
|
-
self
|
|
591
|
-
self._observation_container,
|
|
582
|
+
self,
|
|
592
583
|
inference_data,
|
|
593
584
|
self.model.params,
|
|
594
585
|
background_model=self.background_model,
|