jaxspec 0.1.0__tar.gz → 0.1.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {jaxspec-0.1.0 → jaxspec-0.1.1}/PKG-INFO +2 -2
- {jaxspec-0.1.0 → jaxspec-0.1.1}/pyproject.toml +6 -6
- {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/fit.py +147 -9
- {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/model/abc.py +34 -11
- {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/model/additive.py +1 -1
- {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/model/multiplicative.py +3 -11
- {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/util/typing.py +27 -2
- {jaxspec-0.1.0 → jaxspec-0.1.1}/LICENSE.md +0 -0
- {jaxspec-0.1.0 → jaxspec-0.1.1}/README.md +0 -0
- {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/__init__.py +0 -0
- {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/analysis/__init__.py +0 -0
- {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/analysis/compare.py +0 -0
- {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/analysis/results.py +0 -0
- {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/data/__init__.py +0 -0
- {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/data/grouping.py +0 -0
- {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/data/instrument.py +0 -0
- {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/data/obsconf.py +0 -0
- {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/data/observation.py +0 -0
- {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/data/ogip.py +0 -0
- {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/data/util.py +0 -0
- {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/model/__init__.py +0 -0
- {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/model/_additive/__init__.py +0 -0
- {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/model/_additive/apec.py +0 -0
- {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/model/_additive/apec_loaders.py +0 -0
- {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/model/background.py +0 -0
- {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/model/list.py +0 -0
- {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/scripts/__init__.py +0 -0
- {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/scripts/debug.py +0 -0
- {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/util/__init__.py +0 -0
- {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/util/abundance.py +0 -0
- {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/util/integrate.py +0 -0
- {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/util/online_storage.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: jaxspec
|
|
3
|
-
Version: 0.1.0
|
|
3
|
+
Version: 0.1.1
|
|
4
4
|
Summary: jaxspec is a bayesian spectral fitting library for X-ray astronomy.
|
|
5
5
|
Home-page: https://github.com/renecotyfanboy/jaxspec
|
|
6
6
|
License: MIT
|
|
@@ -24,7 +24,7 @@ Requires-Dist: jaxns (>=2.5.1,<3.0.0)
|
|
|
24
24
|
Requires-Dist: jaxopt (>=0.8.1,<0.9.0)
|
|
25
25
|
Requires-Dist: matplotlib (>=3.8.0,<4.0.0)
|
|
26
26
|
Requires-Dist: mendeleev (>=0.15,<0.18)
|
|
27
|
-
Requires-Dist: mkdocstrings (>=0.24,<0.
|
|
27
|
+
Requires-Dist: mkdocstrings (>=0.24,<0.27)
|
|
28
28
|
Requires-Dist: networkx (>=3.1,<4.0)
|
|
29
29
|
Requires-Dist: numpy (<2.0.0)
|
|
30
30
|
Requires-Dist: numpyro (>=0.15.2,<0.16.0)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[tool.poetry]
|
|
2
2
|
name = "jaxspec"
|
|
3
|
-
version = "0.1.0"
|
|
3
|
+
version = "0.1.1"
|
|
4
4
|
description = "jaxspec is a bayesian spectral fitting library for X-ray astronomy."
|
|
5
5
|
authors = ["sdupourque <sdupourque@irap.omp.eu>"]
|
|
6
6
|
license = "MIT"
|
|
@@ -28,7 +28,7 @@ gpjax = "^0.8.0"
|
|
|
28
28
|
jaxopt = "^0.8.1"
|
|
29
29
|
tinygp = "^0.3.0"
|
|
30
30
|
seaborn = "^0.13.1"
|
|
31
|
-
mkdocstrings = ">=0.24,<0.
|
|
31
|
+
mkdocstrings = ">=0.24,<0.27"
|
|
32
32
|
sparse = "^0.15.1"
|
|
33
33
|
optimistix = "^0.0.7"
|
|
34
34
|
scipy = "<1.15"
|
|
@@ -43,8 +43,8 @@ watermark = "^2.4.3"
|
|
|
43
43
|
[tool.poetry.group.docs.dependencies]
|
|
44
44
|
mkdocs = "^1.5.3"
|
|
45
45
|
mkdocs-material = "^9.4.6"
|
|
46
|
-
mkdocstrings = {extras = ["python"], version = ">=0.24,<0.
|
|
47
|
-
mkdocs-jupyter = "
|
|
46
|
+
mkdocstrings = {extras = ["python"], version = ">=0.24,<0.27"}
|
|
47
|
+
mkdocs-jupyter = ">=0.24.6,<0.26.0"
|
|
48
48
|
|
|
49
49
|
|
|
50
50
|
[tool.poetry.group.test.dependencies]
|
|
@@ -59,7 +59,7 @@ testbook = "^0.4.2"
|
|
|
59
59
|
|
|
60
60
|
[tool.poetry.group.dev.dependencies]
|
|
61
61
|
pre-commit = "^3.5.0"
|
|
62
|
-
ruff = ">=0.2.1,<0.
|
|
62
|
+
ruff = ">=0.2.1,<0.7.0"
|
|
63
63
|
jupyterlab = "^4.0.7"
|
|
64
64
|
notebook = "^7.0.6"
|
|
65
65
|
ipywidgets = "^8.1.1"
|
|
@@ -118,4 +118,4 @@ requires = ["poetry-core"]
|
|
|
118
118
|
build-backend = "poetry.core.masonry.api"
|
|
119
119
|
|
|
120
120
|
[tool.poetry.scripts]
|
|
121
|
-
jaxspec-debug-info = "jaxspec.scripts.debug:debug_info"
|
|
121
|
+
jaxspec-debug-info = "jaxspec.scripts.debug:debug_info"
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import operator
|
|
2
|
+
import warnings
|
|
2
3
|
|
|
3
4
|
from abc import ABC, abstractmethod
|
|
4
5
|
from collections.abc import Callable
|
|
@@ -20,7 +21,7 @@ from jax.tree_util import tree_map
|
|
|
20
21
|
from jax.typing import ArrayLike
|
|
21
22
|
from numpyro.contrib.nested_sampling import NestedSampler
|
|
22
23
|
from numpyro.distributions import Distribution, Poisson, TransformedDistribution
|
|
23
|
-
from numpyro.infer import MCMC, NUTS, Predictive
|
|
24
|
+
from numpyro.infer import AIES, ESS, MCMC, NUTS, Predictive
|
|
24
25
|
from numpyro.infer.initialization import init_to_value
|
|
25
26
|
from numpyro.infer.inspect import get_model_relations
|
|
26
27
|
from numpyro.infer.reparam import TransformReparam
|
|
@@ -181,7 +182,7 @@ class BayesianModel:
|
|
|
181
182
|
|
|
182
183
|
if not callable(prior_distributions):
|
|
183
184
|
# Validate the entry with pydantic
|
|
184
|
-
prior = PriorDictModel(
|
|
185
|
+
prior = PriorDictModel.from_dict(prior_distributions).nested_dict
|
|
185
186
|
|
|
186
187
|
def prior_distributions_func():
|
|
187
188
|
return build_prior(prior, expand_shape=(len(self.observation_container),))
|
|
@@ -293,6 +294,9 @@ class BayesianModel:
|
|
|
293
294
|
that can be fetched with the [`parameter_names`][jaxspec.fit.BayesianModel.parameter_names].
|
|
294
295
|
"""
|
|
295
296
|
|
|
297
|
+
# This is required as numpyro.infer.util.log_densities does not check parameter validity by itself
|
|
298
|
+
numpyro.enable_validation()
|
|
299
|
+
|
|
296
300
|
@jax.jit
|
|
297
301
|
def log_posterior_prob(constrained_params):
|
|
298
302
|
log_posterior_prob, _ = log_density(
|
|
@@ -350,6 +354,69 @@ class BayesianModel:
|
|
|
350
354
|
|
|
351
355
|
|
|
352
356
|
class BayesianModelFitter(BayesianModel, ABC):
|
|
357
|
+
def build_inference_data(
|
|
358
|
+
self,
|
|
359
|
+
posterior_samples,
|
|
360
|
+
num_chains: int = 1,
|
|
361
|
+
num_predictive_samples: int = 1000,
|
|
362
|
+
key: PRNGKey = PRNGKey(0),
|
|
363
|
+
use_transformed_model: bool = False,
|
|
364
|
+
) -> az.InferenceData:
|
|
365
|
+
"""
|
|
366
|
+
Build an InferenceData object from the posterior samples.
|
|
367
|
+
"""
|
|
368
|
+
|
|
369
|
+
numpyro_model = (
|
|
370
|
+
self.transformed_numpyro_model if use_transformed_model else self.numpyro_model
|
|
371
|
+
)
|
|
372
|
+
|
|
373
|
+
keys = random.split(key, 3)
|
|
374
|
+
|
|
375
|
+
posterior_predictive = Predictive(numpyro_model, posterior_samples)(keys[0], observed=False)
|
|
376
|
+
|
|
377
|
+
prior = Predictive(numpyro_model, num_samples=num_predictive_samples * num_chains)(
|
|
378
|
+
keys[1], observed=False
|
|
379
|
+
)
|
|
380
|
+
|
|
381
|
+
log_likelihood = numpyro.infer.log_likelihood(numpyro_model, posterior_samples)
|
|
382
|
+
|
|
383
|
+
seeded_model = numpyro.handlers.substitute(
|
|
384
|
+
numpyro.handlers.seed(numpyro_model, keys[3]),
|
|
385
|
+
substitute_fn=numpyro.infer.init_to_sample,
|
|
386
|
+
)
|
|
387
|
+
|
|
388
|
+
observations = {
|
|
389
|
+
name: site["value"]
|
|
390
|
+
for name, site in numpyro.handlers.trace(seeded_model).get_trace().items()
|
|
391
|
+
if site["type"] == "sample" and site["is_observed"]
|
|
392
|
+
}
|
|
393
|
+
|
|
394
|
+
def reshape_first_dimension(arr):
|
|
395
|
+
new_dim = arr.shape[0] // num_chains
|
|
396
|
+
new_shape = (num_chains, new_dim) + arr.shape[1:]
|
|
397
|
+
reshaped_array = arr.reshape(new_shape)
|
|
398
|
+
|
|
399
|
+
return reshaped_array
|
|
400
|
+
|
|
401
|
+
posterior_samples = {
|
|
402
|
+
key: reshape_first_dimension(value) for key, value in posterior_samples.items()
|
|
403
|
+
}
|
|
404
|
+
prior = {key: value[None, :] for key, value in prior.items()}
|
|
405
|
+
posterior_predictive = {
|
|
406
|
+
key: reshape_first_dimension(value) for key, value in posterior_predictive.items()
|
|
407
|
+
}
|
|
408
|
+
log_likelihood = {
|
|
409
|
+
key: reshape_first_dimension(value) for key, value in log_likelihood.items()
|
|
410
|
+
}
|
|
411
|
+
|
|
412
|
+
return az.from_dict(
|
|
413
|
+
posterior_samples,
|
|
414
|
+
prior=prior,
|
|
415
|
+
posterior_predictive=posterior_predictive,
|
|
416
|
+
log_likelihood=log_likelihood,
|
|
417
|
+
observed_data=observations,
|
|
418
|
+
)
|
|
419
|
+
|
|
353
420
|
@abstractmethod
|
|
354
421
|
def fit(self, **kwargs) -> FitResult: ...
|
|
355
422
|
|
|
@@ -411,18 +478,89 @@ class NUTSFitter(BayesianModelFitter):
|
|
|
411
478
|
|
|
412
479
|
mcmc.run(keys[0])
|
|
413
480
|
|
|
414
|
-
|
|
415
|
-
keys[1], observed=False
|
|
416
|
-
)
|
|
481
|
+
posterior = mcmc.get_samples()
|
|
417
482
|
|
|
418
|
-
|
|
483
|
+
inference_data = filter_inference_data(
|
|
484
|
+
self.build_inference_data(posterior, num_chains=num_chains),
|
|
485
|
+
self.observation_container,
|
|
486
|
+
self.background_model,
|
|
487
|
+
)
|
|
419
488
|
|
|
420
|
-
|
|
421
|
-
|
|
489
|
+
return FitResult(
|
|
490
|
+
self,
|
|
491
|
+
inference_data,
|
|
492
|
+
self.model.params,
|
|
493
|
+
background_model=self.background_model,
|
|
422
494
|
)
|
|
423
495
|
|
|
496
|
+
|
|
497
|
+
class MCMCFitter(BayesianModelFitter):
|
|
498
|
+
"""
|
|
499
|
+
A class to fit a model to a given set of observation using a Bayesian approach. This class uses samplers
|
|
500
|
+
from numpyro to perform the inference on the model parameters.
|
|
501
|
+
"""
|
|
502
|
+
|
|
503
|
+
kernel_dict = {
|
|
504
|
+
"nuts": NUTS,
|
|
505
|
+
"aies": AIES,
|
|
506
|
+
"ess": ESS,
|
|
507
|
+
}
|
|
508
|
+
|
|
509
|
+
def fit(
|
|
510
|
+
self,
|
|
511
|
+
rng_key: int = 0,
|
|
512
|
+
num_chains: int = len(jax.devices()),
|
|
513
|
+
num_warmup: int = 1000,
|
|
514
|
+
num_samples: int = 1000,
|
|
515
|
+
sampler: Literal["nuts", "aies", "ess"] = "nuts",
|
|
516
|
+
kernel_kwargs: dict = {},
|
|
517
|
+
mcmc_kwargs: dict = {},
|
|
518
|
+
) -> FitResult:
|
|
519
|
+
"""
|
|
520
|
+
Fit the model to the data using a MCMC sampler from numpyro.
|
|
521
|
+
|
|
522
|
+
Parameters:
|
|
523
|
+
rng_key: the random key used to initialize the sampler.
|
|
524
|
+
num_chains: the number of chains to run.
|
|
525
|
+
num_warmup: the number of warmup steps.
|
|
526
|
+
num_samples: the number of samples to draw.
|
|
527
|
+
max_tree_depth: the recursion depth of NUTS sampler.
|
|
528
|
+
target_accept_prob: the target acceptance probability for the NUTS sampler.
|
|
529
|
+
dense_mass: whether to use a dense mass for the NUTS sampler.
|
|
530
|
+
mcmc_kwargs: additional arguments to pass to the MCMC sampler. See [`MCMC`][numpyro.infer.mcmc.MCMC] for more details.
|
|
531
|
+
|
|
532
|
+
Returns:
|
|
533
|
+
A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
|
|
534
|
+
"""
|
|
535
|
+
|
|
536
|
+
bayesian_model = self.transformed_numpyro_model
|
|
537
|
+
# bayesian_model = self.numpyro_model(prior_distributions)
|
|
538
|
+
|
|
539
|
+
chain_kwargs = {
|
|
540
|
+
"num_warmup": num_warmup,
|
|
541
|
+
"num_samples": num_samples,
|
|
542
|
+
"num_chains": num_chains,
|
|
543
|
+
}
|
|
544
|
+
|
|
545
|
+
kernel = self.kernel_dict[sampler](bayesian_model, **kernel_kwargs)
|
|
546
|
+
|
|
547
|
+
mcmc_kwargs = chain_kwargs | mcmc_kwargs
|
|
548
|
+
|
|
549
|
+
if sampler in ["aies", "ess"] and mcmc_kwargs.get("chain_method", None) != "vectorized":
|
|
550
|
+
mcmc_kwargs["chain_method"] = "vectorized"
|
|
551
|
+
warnings.warn("The chain_method is set to 'vectorized' for AIES and ESS samplers")
|
|
552
|
+
|
|
553
|
+
mcmc = MCMC(kernel, **mcmc_kwargs)
|
|
554
|
+
keys = random.split(random.PRNGKey(rng_key), 3)
|
|
555
|
+
|
|
556
|
+
mcmc.run(keys[0])
|
|
557
|
+
|
|
558
|
+
posterior = mcmc.get_samples()
|
|
559
|
+
|
|
424
560
|
inference_data = filter_inference_data(
|
|
425
|
-
|
|
561
|
+
self.build_inference_data(posterior, num_chains=num_chains),
|
|
562
|
+
self.observation_container,
|
|
563
|
+
self.background_model,
|
|
426
564
|
)
|
|
427
565
|
|
|
428
566
|
return FitResult(
|
|
@@ -1,12 +1,15 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from uuid import uuid4
|
|
5
|
+
|
|
2
6
|
import haiku as hk
|
|
3
7
|
import jax
|
|
4
8
|
import jax.numpy as jnp
|
|
5
9
|
import networkx as nx
|
|
10
|
+
|
|
6
11
|
from haiku._src import base
|
|
7
|
-
from uuid import uuid4
|
|
8
12
|
from jax.scipy.integrate import trapezoid
|
|
9
|
-
from abc import ABC
|
|
10
13
|
from simpleeval import simple_eval
|
|
11
14
|
|
|
12
15
|
|
|
@@ -215,16 +218,18 @@ class SpectralModel:
|
|
|
215
218
|
continuum[node_id] = runtime_modules[node_id].continuum(energies)
|
|
216
219
|
|
|
217
220
|
elif node and node["type"] == "operation":
|
|
218
|
-
component_1 = list(self.graph.in_edges(node_id))[0][0]
|
|
221
|
+
component_1 = list(self.graph.in_edges(node_id))[0][0] # noqa: RUF015
|
|
219
222
|
component_2 = list(self.graph.in_edges(node_id))[1][0]
|
|
220
|
-
continuum[node_id] = node["function"](
|
|
223
|
+
continuum[node_id] = node["function"](
|
|
224
|
+
continuum[component_1], continuum[component_2]
|
|
225
|
+
)
|
|
221
226
|
|
|
222
227
|
if n_points == 2:
|
|
223
|
-
flux_1D = continuum[list(self.graph.in_edges("out"))[0][0]]
|
|
228
|
+
flux_1D = continuum[list(self.graph.in_edges("out"))[0][0]] # noqa: RUF015
|
|
224
229
|
flux = jnp.stack((flux_1D[:-1], flux_1D[1:]))
|
|
225
230
|
|
|
226
231
|
else:
|
|
227
|
-
flux = continuum[list(self.graph.in_edges("out"))[0][0]]
|
|
232
|
+
flux = continuum[list(self.graph.in_edges("out"))[0][0]] # noqa: RUF015
|
|
228
233
|
|
|
229
234
|
if energy_flux:
|
|
230
235
|
continuum_flux = trapezoid(
|
|
@@ -234,7 +239,9 @@ class SpectralModel:
|
|
|
234
239
|
)
|
|
235
240
|
|
|
236
241
|
else:
|
|
237
|
-
continuum_flux = trapezoid(
|
|
242
|
+
continuum_flux = trapezoid(
|
|
243
|
+
flux * energies_to_integrate, x=jnp.log(energies_to_integrate), axis=0
|
|
244
|
+
)
|
|
238
245
|
|
|
239
246
|
# Iterate from the root nodes to the output node and
|
|
240
247
|
# compute the fine structure contribution for each component
|
|
@@ -249,14 +256,18 @@ class SpectralModel:
|
|
|
249
256
|
path = nx.shortest_path(self.graph, source=root_node_id, target="out")
|
|
250
257
|
nodes_id_in_path = [node_id for node_id in path]
|
|
251
258
|
|
|
252
|
-
flux_from_component, mean_energy = runtime_modules[root_node_id].emission_lines(
|
|
259
|
+
flux_from_component, mean_energy = runtime_modules[root_node_id].emission_lines(
|
|
260
|
+
e_low, e_high
|
|
261
|
+
)
|
|
253
262
|
|
|
254
263
|
multiplicative_nodes = []
|
|
255
264
|
|
|
256
265
|
# Search all multiplicative components connected to this node
|
|
257
266
|
# and apply them at mean energy
|
|
258
267
|
for node_id in nodes_id_in_path[::-1]:
|
|
259
|
-
multiplicative_nodes.extend(
|
|
268
|
+
multiplicative_nodes.extend(
|
|
269
|
+
[node_id for node_id in self.find_multiplicative_components(node_id)]
|
|
270
|
+
)
|
|
260
271
|
|
|
261
272
|
for mul_node in multiplicative_nodes:
|
|
262
273
|
flux_from_component *= runtime_modules[mul_node].continuum(mean_energy)
|
|
@@ -309,7 +320,10 @@ class SpectralModel:
|
|
|
309
320
|
if component.type == "additive":
|
|
310
321
|
|
|
311
322
|
def lam_func(e):
|
|
312
|
-
return
|
|
323
|
+
return (
|
|
324
|
+
component(**kwargs).continuum(e)
|
|
325
|
+
+ component(**kwargs).emission_lines(e, e + 1)[0]
|
|
326
|
+
)
|
|
313
327
|
|
|
314
328
|
elif component.type == "multiplicative":
|
|
315
329
|
|
|
@@ -342,7 +356,9 @@ class SpectralModel:
|
|
|
342
356
|
|
|
343
357
|
return cls(graph, labels)
|
|
344
358
|
|
|
345
|
-
def compose(
|
|
359
|
+
def compose(
|
|
360
|
+
self, other: SpectralModel, operation=None, function=None, name=None
|
|
361
|
+
) -> SpectralModel:
|
|
346
362
|
"""
|
|
347
363
|
This function operate a composition between the operation graph of two models
|
|
348
364
|
1) It fuses the two graphs using which joins at the 'out' nodes
|
|
@@ -524,3 +540,10 @@ class AdditiveComponent(ModelComponent, ABC):
|
|
|
524
540
|
|
|
525
541
|
return jnp.trapz(self(x) * dx, x=t)
|
|
526
542
|
'''
|
|
543
|
+
|
|
544
|
+
|
|
545
|
+
class MultiplicativeComponent(ModelComponent, ABC):
|
|
546
|
+
type = "multiplicative"
|
|
547
|
+
|
|
548
|
+
@abstractmethod
|
|
549
|
+
def continuum(self, energy): ...
|
|
@@ -1,7 +1,5 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from abc import ABC, abstractmethod
|
|
4
|
-
|
|
5
3
|
import haiku as hk
|
|
6
4
|
import jax.numpy as jnp
|
|
7
5
|
import numpy as np
|
|
@@ -10,14 +8,7 @@ from astropy.table import Table
|
|
|
10
8
|
from haiku.initializers import Constant as HaikuConstant
|
|
11
9
|
|
|
12
10
|
from ..util.online_storage import table_manager
|
|
13
|
-
from .abc import ModelComponent
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
class MultiplicativeComponent(ModelComponent, ABC):
|
|
17
|
-
type = "multiplicative"
|
|
18
|
-
|
|
19
|
-
@abstractmethod
|
|
20
|
-
def continuum(self, energy): ...
|
|
11
|
+
from .abc import MultiplicativeComponent
|
|
21
12
|
|
|
22
13
|
|
|
23
14
|
class Expfac(MultiplicativeComponent):
|
|
@@ -226,6 +217,7 @@ class Tbpcf(MultiplicativeComponent):
|
|
|
226
217
|
|
|
227
218
|
return f * jnp.exp(-nh * sigma) + (1 - f)
|
|
228
219
|
|
|
220
|
+
|
|
229
221
|
class FDcut(MultiplicativeComponent):
|
|
230
222
|
r"""
|
|
231
223
|
A Fermi-Dirac cutoff model.
|
|
@@ -243,4 +235,4 @@ class FDcut(MultiplicativeComponent):
|
|
|
243
235
|
cutoff = hk.get_parameter("E_c", [], init=HaikuConstant(1))
|
|
244
236
|
folding = hk.get_parameter("E_f", [], init=HaikuConstant(1))
|
|
245
237
|
|
|
246
|
-
return (1 + jnp.exp((energy - cutoff)/folding)) ** -1
|
|
238
|
+
return (1 + jnp.exp((energy - cutoff) / folding)) ** -1
|
|
@@ -9,6 +9,13 @@ from pydantic import BaseModel, field_validator
|
|
|
9
9
|
PriorDictType = dict[str, dict[str, dist.Distribution | ArrayLike]]
|
|
10
10
|
|
|
11
11
|
|
|
12
|
+
def is_flat_dict(input_data: dict[str, Any]) -> bool:
|
|
13
|
+
"""
|
|
14
|
+
Check if the input data is a flat dictionary with string keys and non-dictionary values.
|
|
15
|
+
"""
|
|
16
|
+
return all(isinstance(k, str) and not isinstance(v, dict) for k, v in input_data.items())
|
|
17
|
+
|
|
18
|
+
|
|
12
19
|
class PriorDictModel(BaseModel):
|
|
13
20
|
"""
|
|
14
21
|
Pydantic model for a nested dictionary of NumPyro distributions or JAX arrays.
|
|
@@ -21,6 +28,23 @@ class PriorDictModel(BaseModel):
|
|
|
21
28
|
class Config: # noqa D106
|
|
22
29
|
arbitrary_types_allowed = True
|
|
23
30
|
|
|
31
|
+
@classmethod
|
|
32
|
+
def from_dict(cls, input_prior: dict[str, Any]):
|
|
33
|
+
if is_flat_dict(input_prior):
|
|
34
|
+
nested_dict = {}
|
|
35
|
+
|
|
36
|
+
for key, obj in input_prior.items():
|
|
37
|
+
component, component_number, *parameter = key.split("_")
|
|
38
|
+
|
|
39
|
+
sub_dict = nested_dict.get(f"{component}_{component_number}", {})
|
|
40
|
+
sub_dict["_".join(parameter)] = obj
|
|
41
|
+
|
|
42
|
+
nested_dict[f"{component}_{component_number}"] = sub_dict
|
|
43
|
+
|
|
44
|
+
return cls(nested_dict=nested_dict)
|
|
45
|
+
|
|
46
|
+
return cls(nested_dict=input_prior)
|
|
47
|
+
|
|
24
48
|
@field_validator("nested_dict", mode="before")
|
|
25
49
|
def check_and_cast_nested_dict(cls, value: dict[str, Any]):
|
|
26
50
|
if not isinstance(value, dict):
|
|
@@ -35,9 +59,10 @@ class PriorDictModel(BaseModel):
|
|
|
35
59
|
try:
|
|
36
60
|
# Attempt to cast to JAX array
|
|
37
61
|
value[key][inner_key] = jnp.array(obj, dtype=float)
|
|
62
|
+
|
|
38
63
|
except Exception as e:
|
|
39
64
|
raise ValueError(
|
|
40
|
-
f'The value for key "{inner_key}" in
|
|
41
|
-
f"
|
|
65
|
+
f'The value for key "{inner_key}" in {key} be a NumPyro '
|
|
66
|
+
f"distribution or castable to JAX array. Error: {e}"
|
|
42
67
|
)
|
|
43
68
|
return value
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|