jaxspec 0.2.2.dev0__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxspec/analysis/_plot.py +5 -5
- jaxspec/analysis/results.py +38 -25
- jaxspec/data/obsconf.py +9 -3
- jaxspec/data/observation.py +3 -1
- jaxspec/data/ogip.py +9 -2
- jaxspec/data/util.py +17 -11
- jaxspec/experimental/interpolator.py +74 -0
- jaxspec/experimental/interpolator_jax.py +79 -0
- jaxspec/experimental/intrument_models.py +159 -0
- jaxspec/experimental/nested_sampler.py +78 -0
- jaxspec/experimental/tabulated.py +264 -0
- jaxspec/fit/__init__.py +3 -0
- jaxspec/{fit.py → fit/_bayesian_model.py} +86 -338
- jaxspec/{_fit → fit}/_build_model.py +42 -6
- jaxspec/fit/_fitter.py +255 -0
- jaxspec/model/abc.py +52 -80
- jaxspec/model/additive.py +14 -5
- jaxspec/model/background.py +17 -14
- jaxspec/model/instrument.py +81 -0
- jaxspec/model/list.py +4 -1
- jaxspec/model/multiplicative.py +32 -12
- jaxspec/util/integrate.py +17 -5
- {jaxspec-0.2.2.dev0.dist-info → jaxspec-0.3.0.dist-info}/METADATA +9 -9
- jaxspec-0.3.0.dist-info/RECORD +42 -0
- jaxspec-0.2.2.dev0.dist-info/RECORD +0 -34
- /jaxspec/{_fit → experimental}/__init__.py +0 -0
- {jaxspec-0.2.2.dev0.dist-info → jaxspec-0.3.0.dist-info}/WHEEL +0 -0
- {jaxspec-0.2.2.dev0.dist-info → jaxspec-0.3.0.dist-info}/entry_points.txt +0 -0
- {jaxspec-0.2.2.dev0.dist-info → jaxspec-0.3.0.dist-info}/licenses/LICENSE.md +0 -0
jaxspec/fit/_fitter.py
ADDED
|
@@ -0,0 +1,255 @@
|
|
|
1
|
+
import warnings
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from typing import Literal
|
|
5
|
+
|
|
6
|
+
import arviz as az
|
|
7
|
+
import jax
|
|
8
|
+
import matplotlib.pyplot as plt
|
|
9
|
+
import numpyro
|
|
10
|
+
|
|
11
|
+
from jax import random
|
|
12
|
+
from jax.random import PRNGKey
|
|
13
|
+
from numpyro.infer import AIES, ESS, MCMC, NUTS, SVI, Predictive, Trace_ELBO
|
|
14
|
+
from numpyro.infer.autoguide import AutoMultivariateNormal
|
|
15
|
+
|
|
16
|
+
from ..analysis.results import FitResult
|
|
17
|
+
from ._bayesian_model import BayesianModel
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class BayesianModelFitter(BayesianModel, ABC):
    """
    Abstract base for fitters built on top of a `BayesianModel`.

    It provides the machinery to turn raw posterior samples into an
    [InferenceData][arviz.InferenceData] object; subclasses implement `fit`.
    """

    def build_inference_data(
        self,
        posterior_samples,
        num_chains: int = 1,
        num_predictive_samples: int = 1000,
        key: "jax.Array | None" = None,
        use_transformed_model: bool = False,
        filter_inference_data: bool = True,
    ) -> az.InferenceData:
        """
        Build an [InferenceData][arviz.InferenceData] object from posterior samples.

        Parameters:
            posterior_samples: dictionary of posterior samples whose leading axis holds all chains
                concatenated; it is split into `num_chains` equal parts (assumed chain-major order).
            num_chains: the number of chains used to sample the posterior.
            num_predictive_samples: the number of prior draws per chain. In total
                `num_predictive_samples * num_chains` prior samples are drawn and stored as one chain.
            key: the random key used for the posterior-predictive and prior draws. Defaults to
                `PRNGKey(42)`, created when this method is called rather than at import time.
            use_transformed_model: whether to use the transformed model to build the InferenceData.
            filter_inference_data: whether to filter the InferenceData to keep only the relevant parameters.

        Returns:
            The InferenceData, passed through `filter_inference_data` when requested.
        """
        if key is None:
            # Built lazily: creating a key at import time would initialise the JAX backend.
            key = PRNGKey(42)

        numpyro_model = (
            self.transformed_numpyro_model if use_transformed_model else self.numpyro_model
        )

        # One key per random operation: posterior predictive, prior, and the trace of observed sites.
        predictive_key, prior_key, trace_key = random.split(key, 3)

        posterior_predictive = Predictive(numpyro_model, posterior_samples)(
            predictive_key, observed=False
        )

        prior = Predictive(numpyro_model, num_samples=num_predictive_samples * num_chains)(
            prior_key, observed=False
        )

        log_likelihood = numpyro.infer.log_likelihood(numpyro_model, posterior_samples)

        # Trace the model once, substituting latent sites via `init_to_sample`, to collect the
        # values attached to the observed sample sites.
        seeded_model = numpyro.handlers.substitute(
            numpyro.handlers.seed(numpyro_model, trace_key),
            substitute_fn=numpyro.infer.init_to_sample,
        )

        observations = {
            name: site["value"]
            for name, site in numpyro.handlers.trace(seeded_model).get_trace().items()
            if site["type"] == "sample" and site["is_observed"]
        }

        def split_chains(arr):
            # (num_chains * n, ...) -> (num_chains, n, ...)
            new_shape = (num_chains, arr.shape[0] // num_chains) + arr.shape[1:]
            return arr.reshape(new_shape)

        posterior_samples = {name: split_chains(value) for name, value in posterior_samples.items()}
        # The prior draws are stored as a single chain.
        prior = {name: value[None, :] for name, value in prior.items()}
        posterior_predictive = {
            name: split_chains(value) for name, value in posterior_predictive.items()
        }
        log_likelihood = {name: split_chains(value) for name, value in log_likelihood.items()}

        inference_data = az.from_dict(
            posterior_samples,
            prior=prior,
            posterior_predictive=posterior_predictive,
            log_likelihood=log_likelihood,
            observed_data=observations,
        )

        if filter_inference_data:
            return self.filter_inference_data(inference_data)
        return inference_data

    def filter_inference_data(
        self,
        inference_data: az.InferenceData,
    ) -> az.InferenceData:
        """
        Filter the inference data to keep only the relevant parameters for the observations.

        - The posterior predictive group keeps only `obs/~/<name>` for each observation, plus
          `bkg/~/<name>` when a background model is set. This drops other predictive sites,
          e.g. deterministic variables such as the kernel of a background GP.
        - The posterior and prior groups drop variables whose name ends with `"_base"`
          (reparametrised variables) or starts with `"_"`.

        The groups of `inference_data` are reassigned in place and the same object is returned.
        """
        predictive_parameters = []

        for name, _ in self._observation_container.items():
            predictive_parameters.append(f"obs/~/{name}")
            if self.background_model is not None:
                predictive_parameters.append(f"bkg/~/{name}")

        inference_data.posterior_predictive = inference_data.posterior_predictive[
            predictive_parameters
        ]

        parameters = [
            x
            for x in inference_data.posterior.keys()
            if not (x.endswith("_base") or x.startswith("_"))
        ]
        inference_data.posterior = inference_data.posterior[parameters]
        inference_data.prior = inference_data.prior[parameters]

        return inference_data

    @abstractmethod
    def fit(self, **kwargs) -> FitResult:
        """Run the inference and return a [`FitResult`][jaxspec.analysis.results.FitResult]."""
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
class MCMCFitter(BayesianModelFitter):
    """
    A class to fit a model to a given set of observations using a Bayesian approach. This class uses samplers
    from numpyro to perform the inference on the model parameters.
    """

    # Maps the `sampler` argument of `fit` to the numpyro kernel class.
    kernel_dict = {
        "nuts": NUTS,
        "aies": AIES,
        "ess": ESS,
    }

    def fit(
        self,
        rng_key: int = 0,
        num_chains: "int | None" = None,
        num_warmup: int = 1000,
        num_samples: int = 1000,
        sampler: Literal["nuts", "aies", "ess"] = "nuts",
        use_transformed_model: bool = True,
        kernel_kwargs: "dict | None" = None,
        mcmc_kwargs: "dict | None" = None,
    ) -> FitResult:
        """
        Fit the model to the data using a MCMC sampler from numpyro.

        Parameters:
            rng_key: the integer seed of the random key used to run the sampler.
            num_chains: the number of chains to run. Defaults to `len(jax.devices())`, evaluated
                when `fit` is called rather than when the module is imported.
            num_warmup: the number of warmup steps.
            num_samples: the number of samples to draw.
            sampler: the sampler to use. Can be one of "nuts", "aies" or "ess".
            use_transformed_model: whether to sample the transformed model (and build the
                InferenceData from it) rather than the plain numpyro model.
            kernel_kwargs: additional arguments to pass to the kernel. See [`NUTS`][numpyro.infer.mcmc.MCMCKernel] for more details.
            mcmc_kwargs: additional arguments to pass to the MCMC sampler. See [`MCMC`][numpyro.infer.mcmc.MCMC] for more details.
                Entries here override `num_warmup`, `num_samples` and `num_chains`.

        Returns:
            A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
        """
        if num_chains is None:
            num_chains = len(jax.devices())
        kernel_kwargs = {} if kernel_kwargs is None else kernel_kwargs
        mcmc_kwargs = {} if mcmc_kwargs is None else mcmc_kwargs

        bayesian_model = (
            self.transformed_numpyro_model if use_transformed_model else self.numpyro_model
        )

        kernel = self.kernel_dict[sampler](bayesian_model, **kernel_kwargs)

        # User-provided `mcmc_kwargs` take precedence. `|` builds a new dict, so the caller's
        # dictionary is never mutated below.
        mcmc_kwargs = {
            "num_warmup": num_warmup,
            "num_samples": num_samples,
            "num_chains": num_chains,
        } | mcmc_kwargs

        if sampler in ["aies", "ess"] and mcmc_kwargs.get("chain_method", None) != "vectorized":
            mcmc_kwargs["chain_method"] = "vectorized"
            warnings.warn(
                "The chain_method is set to 'vectorized' for AIES and ESS samplers", stacklevel=2
            )

        mcmc = MCMC(kernel, **mcmc_kwargs)
        keys = random.split(random.PRNGKey(rng_key), 3)

        mcmc.run(keys[0])

        # Flat samples: all chains concatenated along the leading axis.
        posterior = mcmc.get_samples()

        # Split the samples using the chain count MCMC actually ran with. It may come from
        # `mcmc_kwargs` rather than from the `num_chains` argument.
        # The predictive draws use `build_inference_data`'s own default key, not `rng_key`.
        inference_data = self.build_inference_data(
            posterior,
            num_chains=mcmc_kwargs["num_chains"],
            use_transformed_model=use_transformed_model,
        )

        return FitResult(
            self,
            inference_data,
            background_model=self.background_model,
        )
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
class VIFitter(BayesianModelFitter):
    """
    A class to fit a model to a given set of observations using stochastic variational inference
    (SVI) from numpyro.
    """

    def fit(
        self,
        rng_key: int = 0,
        num_steps: int = 10_000,
        optimizer=None,
        loss=None,
        num_samples: int = 1000,
        guide=None,
        use_transformed_model: bool = True,
        plot_diagnostics: bool = False,
    ) -> FitResult:
        """
        Fit the model to the data using SVI from numpyro.

        Parameters:
            rng_key: the integer seed of the random key used for optimisation and posterior draws.
            num_steps: the number of SVI optimisation steps.
            optimizer: the numpyro optimizer. Defaults to a new `numpyro.optim.Adam(step_size=0.0005)`
                created for each call.
            loss: the variational objective. Defaults to a new `Trace_ELBO()` created for each call.
            num_samples: the number of posterior samples drawn from the fitted guide.
            guide: the variational guide. Defaults to an `AutoMultivariateNormal` built on the model.
            use_transformed_model: whether to fit the transformed model rather than the plain numpyro model.
            plot_diagnostics: if True, plot the loss history with `matplotlib.pyplot` (current axes, log scale).

        Returns:
            A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
        """
        # Defaults are built per call instead of once at import and shared across every fit.
        if optimizer is None:
            optimizer = numpyro.optim.Adam(step_size=0.0005)
        if loss is None:
            loss = Trace_ELBO()

        bayesian_model = (
            self.transformed_numpyro_model if use_transformed_model else self.numpyro_model
        )

        if guide is None:
            guide = AutoMultivariateNormal(bayesian_model)

        svi = SVI(bayesian_model, guide, optimizer, loss=loss)

        keys = random.split(random.PRNGKey(rng_key), 3)
        svi_result = svi.run(keys[0], num_steps)
        params = svi_result.params

        if plot_diagnostics:
            plt.plot(svi_result.losses)
            plt.xlabel("Steps")
            plt.ylabel("ELBO loss")
            plt.semilogy()

        # Draw posterior samples from the fitted guide; they form a single "chain".
        predictive = Predictive(guide, params=params, num_samples=num_samples)
        posterior = predictive(keys[1])

        inference_data = self.build_inference_data(
            posterior, num_chains=1, use_transformed_model=use_transformed_model
        )

        return FitResult(
            self,
            inference_data,
            background_model=self.background_model,
        )
|
jaxspec/model/abc.py
CHANGED
|
@@ -11,8 +11,7 @@ import jax
|
|
|
11
11
|
import jax.numpy as jnp
|
|
12
12
|
import jax.scipy as jsp
|
|
13
13
|
import networkx as nx
|
|
14
|
-
|
|
15
|
-
from simpleeval import simple_eval
|
|
14
|
+
import numpy as np
|
|
16
15
|
|
|
17
16
|
from jaxspec.util.typing import PriorDictType
|
|
18
17
|
|
|
@@ -31,14 +30,14 @@ def set_parameters(params: PriorDictType, state: nnx.State) -> nnx.State:
|
|
|
31
30
|
A spectral model with the newly set parameters.
|
|
32
31
|
"""
|
|
33
32
|
|
|
34
|
-
state_dict =
|
|
33
|
+
state_dict = nnx.to_pure_dict(state) # haiku-like 2 level dictionary
|
|
35
34
|
|
|
36
35
|
for key, value in params.items():
|
|
37
36
|
# Split the key to extract the module name and parameter name
|
|
38
37
|
module_name, param_name = key.rsplit("_", 1)
|
|
39
|
-
state_dict["
|
|
38
|
+
state_dict["components"][module_name][param_name] = value
|
|
40
39
|
|
|
41
|
-
|
|
40
|
+
nnx.replace_by_pure_dict(state, state_dict)
|
|
42
41
|
|
|
43
42
|
return state
|
|
44
43
|
|
|
@@ -71,52 +70,16 @@ class Composable(ABC):
|
|
|
71
70
|
|
|
72
71
|
|
|
73
72
|
class SpectralModel(nnx.Module, Composable):
|
|
74
|
-
|
|
75
|
-
# modules: dict
|
|
73
|
+
_graph: nx.DiGraph
|
|
76
74
|
|
|
77
75
|
def __init__(self, graph: nx.DiGraph):
|
|
78
|
-
self.
|
|
79
|
-
self.
|
|
76
|
+
self._graph = graph
|
|
77
|
+
self.components = {}
|
|
78
|
+
self._energy_grid = np.geomspace(0.1, 50, 1000, dtype=np.float64)
|
|
80
79
|
|
|
81
|
-
for node, data in self.
|
|
80
|
+
for node, data in self._graph.nodes(data=True):
|
|
82
81
|
if "component" in data["type"]:
|
|
83
|
-
self.
|
|
84
|
-
|
|
85
|
-
@classmethod
|
|
86
|
-
def from_string(cls, string: str) -> SpectralModel:
|
|
87
|
-
"""
|
|
88
|
-
This constructor enable to build a model from a string. The string should be a valid python expression, with
|
|
89
|
-
the following constraints :
|
|
90
|
-
|
|
91
|
-
* The model components should be defined in the jaxspec.model.list module
|
|
92
|
-
* The model components should be separated by a * or a + (no convolution yet)
|
|
93
|
-
* The model components should be written with their parameters in parentheses
|
|
94
|
-
|
|
95
|
-
Parameters:
|
|
96
|
-
string : The string to parse
|
|
97
|
-
|
|
98
|
-
Examples:
|
|
99
|
-
An absorbed model with a powerlaw and a blackbody:
|
|
100
|
-
|
|
101
|
-
>>> model = SpectralModel.from_string("Tbabs()*(Powerlaw() + Blackbody())")
|
|
102
|
-
"""
|
|
103
|
-
|
|
104
|
-
from .list import model_components
|
|
105
|
-
|
|
106
|
-
return simple_eval(string, functions=model_components)
|
|
107
|
-
|
|
108
|
-
def to_string(self) -> str:
|
|
109
|
-
"""
|
|
110
|
-
This method return the string representation of the model.
|
|
111
|
-
|
|
112
|
-
Examples:
|
|
113
|
-
Build a model from a string and convert it back to a string:
|
|
114
|
-
|
|
115
|
-
>>> model = SpectralModel.from_string("Tbabs()*(Powerlaw() + Blackbody())")
|
|
116
|
-
>>> model.to_string()
|
|
117
|
-
"Tbabs()*(Powerlaw() + Blackbody())"
|
|
118
|
-
"""
|
|
119
|
-
return str(self)
|
|
82
|
+
self.components[data["name"]] = data["component"] # (**data['kwargs'])
|
|
120
83
|
|
|
121
84
|
"""
|
|
122
85
|
def __str__(self) -> str:
|
|
@@ -153,7 +116,7 @@ class SpectralModel(nnx.Module, Composable):
|
|
|
153
116
|
"""
|
|
154
117
|
|
|
155
118
|
composed_graph = compose(
|
|
156
|
-
self.
|
|
119
|
+
self._graph, other._graph, operation=operation, operation_func=operation_func
|
|
157
120
|
)
|
|
158
121
|
|
|
159
122
|
return SpectralModel(composed_graph)
|
|
@@ -161,7 +124,7 @@ class SpectralModel(nnx.Module, Composable):
|
|
|
161
124
|
@classmethod
|
|
162
125
|
def from_component(cls, component):
|
|
163
126
|
node_id = str(uuid4())
|
|
164
|
-
|
|
127
|
+
_graph = nx.DiGraph()
|
|
165
128
|
|
|
166
129
|
node_properties = {
|
|
167
130
|
"type": f"{component.type}_component",
|
|
@@ -170,26 +133,26 @@ class SpectralModel(nnx.Module, Composable):
|
|
|
170
133
|
"depth": 0,
|
|
171
134
|
}
|
|
172
135
|
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
136
|
+
_graph.add_node(node_id, **node_properties)
|
|
137
|
+
_graph.add_node("out", type="out", depth=1)
|
|
138
|
+
_graph.add_edge(node_id, "out")
|
|
176
139
|
|
|
177
|
-
return cls(
|
|
140
|
+
return cls(_graph)
|
|
178
141
|
|
|
179
142
|
def _find_multiplicative_components(self, node_id):
|
|
180
143
|
"""
|
|
181
144
|
Recursively finds all the multiplicative components connected to the node with the given ID.
|
|
182
145
|
"""
|
|
183
|
-
node = self.
|
|
146
|
+
node = self._graph.nodes[node_id]
|
|
184
147
|
multiplicative_nodes = []
|
|
185
148
|
|
|
186
149
|
if node.get("type") == "mul_operation":
|
|
187
150
|
# Recursively find all the multiplicative components using the predecessors
|
|
188
|
-
predecessors = self.
|
|
151
|
+
predecessors = self._graph.pred[node_id]
|
|
189
152
|
for node_id in predecessors:
|
|
190
|
-
if "multiplicative_component" == self.
|
|
153
|
+
if "multiplicative_component" == self._graph.nodes[node_id].get("type"):
|
|
191
154
|
multiplicative_nodes.append(node_id)
|
|
192
|
-
elif "mul_operation" == self.
|
|
155
|
+
elif "mul_operation" == self._graph.nodes[node_id].get("type"):
|
|
193
156
|
multiplicative_nodes.extend(self._find_multiplicative_components(node_id))
|
|
194
157
|
|
|
195
158
|
return multiplicative_nodes
|
|
@@ -198,8 +161,8 @@ class SpectralModel(nnx.Module, Composable):
|
|
|
198
161
|
def root_nodes(self) -> list[str]:
|
|
199
162
|
return [
|
|
200
163
|
node_id
|
|
201
|
-
for node_id, in_degree in self.
|
|
202
|
-
if in_degree == 0 and ("additive" in self.
|
|
164
|
+
for node_id, in_degree in self._graph.in_degree(self._graph.nodes)
|
|
165
|
+
if in_degree == 0 and ("additive" in self._graph.nodes[node_id].get("type"))
|
|
203
166
|
]
|
|
204
167
|
|
|
205
168
|
@property
|
|
@@ -207,8 +170,8 @@ class SpectralModel(nnx.Module, Composable):
|
|
|
207
170
|
branches = []
|
|
208
171
|
|
|
209
172
|
for root_node_id in self.root_nodes:
|
|
210
|
-
root_node_name = self.
|
|
211
|
-
path = nx.shortest_path(self.
|
|
173
|
+
root_node_name = self._graph.nodes[root_node_id].get("name")
|
|
174
|
+
path = nx.shortest_path(self._graph, source=root_node_id, target="out")
|
|
212
175
|
multiplicative_components = []
|
|
213
176
|
|
|
214
177
|
# Search all multiplicative components connected to this node
|
|
@@ -218,10 +181,12 @@ class SpectralModel(nnx.Module, Composable):
|
|
|
218
181
|
[node_id for node_id in self._find_multiplicative_components(node_id)]
|
|
219
182
|
)
|
|
220
183
|
|
|
184
|
+
multiplicative_components = set(multiplicative_components)
|
|
185
|
+
|
|
221
186
|
branch = ""
|
|
222
187
|
|
|
223
188
|
for multiplicative_node_id in multiplicative_components:
|
|
224
|
-
multiplicative_node_name = self.
|
|
189
|
+
multiplicative_node_name = self._graph.nodes[multiplicative_node_id].get("name")
|
|
225
190
|
branch += f"{multiplicative_node_name}*"
|
|
226
191
|
|
|
227
192
|
branch += f"{root_node_name}"
|
|
@@ -233,12 +198,12 @@ class SpectralModel(nnx.Module, Composable):
|
|
|
233
198
|
continuum = {}
|
|
234
199
|
|
|
235
200
|
## Evaluate the expected contribution for each component
|
|
236
|
-
for node_id in nx.dag.topological_sort(self.
|
|
237
|
-
node = self.
|
|
201
|
+
for node_id in nx.dag.topological_sort(self._graph):
|
|
202
|
+
node = self._graph.nodes[node_id]
|
|
238
203
|
|
|
239
204
|
if node["type"] == "additive_component":
|
|
240
205
|
node_name = node["name"]
|
|
241
|
-
runtime_modules = self.
|
|
206
|
+
runtime_modules = self.components[node_name]
|
|
242
207
|
|
|
243
208
|
if not energy_flux:
|
|
244
209
|
continuum[node_name] = runtime_modules._photon_flux(
|
|
@@ -252,8 +217,8 @@ class SpectralModel(nnx.Module, Composable):
|
|
|
252
217
|
|
|
253
218
|
elif node["type"] == "multiplicative_component":
|
|
254
219
|
node_name = node["name"]
|
|
255
|
-
runtime_modules = self.
|
|
256
|
-
continuum[node_name] = runtime_modules.
|
|
220
|
+
runtime_modules = self.components[node_name]
|
|
221
|
+
continuum[node_name] = runtime_modules._factor(e_low, e_high, n_points=n_points)
|
|
257
222
|
|
|
258
223
|
else:
|
|
259
224
|
pass
|
|
@@ -264,10 +229,10 @@ class SpectralModel(nnx.Module, Composable):
|
|
|
264
229
|
branches = {}
|
|
265
230
|
|
|
266
231
|
for root_node_id in root_nodes:
|
|
267
|
-
root_node_name = self.
|
|
232
|
+
root_node_name = self._graph.nodes[root_node_id].get("name")
|
|
268
233
|
root_continuum = continuum[root_node_name]
|
|
269
234
|
|
|
270
|
-
path = nx.shortest_path(self.
|
|
235
|
+
path = nx.shortest_path(self._graph, source=root_node_id, target="out")
|
|
271
236
|
multiplicative_components = []
|
|
272
237
|
|
|
273
238
|
# Search all multiplicative components connected to this node
|
|
@@ -279,8 +244,8 @@ class SpectralModel(nnx.Module, Composable):
|
|
|
279
244
|
|
|
280
245
|
branch = ""
|
|
281
246
|
|
|
282
|
-
for multiplicative_node_id in multiplicative_components:
|
|
283
|
-
multiplicative_node_name = self.
|
|
247
|
+
for multiplicative_node_id in set(multiplicative_components):
|
|
248
|
+
multiplicative_node_name = self._graph.nodes[multiplicative_node_id].get("name")
|
|
284
249
|
root_continuum *= continuum[multiplicative_node_name]
|
|
285
250
|
branch += f"{multiplicative_node_name}*"
|
|
286
251
|
|
|
@@ -302,7 +267,7 @@ class SpectralModel(nnx.Module, Composable):
|
|
|
302
267
|
Returns:
|
|
303
268
|
A string containing the mermaid representation of the model.
|
|
304
269
|
"""
|
|
305
|
-
return export_to_mermaid(self.
|
|
270
|
+
return export_to_mermaid(self._graph, file)
|
|
306
271
|
|
|
307
272
|
@partial(jax.jit, static_argnums=0, static_argnames=("n_points", "split_branches"))
|
|
308
273
|
def photon_flux(self, params, e_low, e_high, n_points=2, split_branches=False):
|
|
@@ -327,12 +292,12 @@ class SpectralModel(nnx.Module, Composable):
|
|
|
327
292
|
instead.
|
|
328
293
|
"""
|
|
329
294
|
|
|
330
|
-
graphdef,
|
|
331
|
-
|
|
295
|
+
graphdef, parameters, tables = nnx.split(self, nnx.Param, ...)
|
|
296
|
+
parameters = set_parameters(params, parameters)
|
|
332
297
|
|
|
333
|
-
return nnx.
|
|
298
|
+
return nnx.merge(graphdef, parameters, tables).turbo_flux(
|
|
334
299
|
e_low, e_high, n_points=n_points, return_branches=split_branches
|
|
335
|
-
)
|
|
300
|
+
)
|
|
336
301
|
|
|
337
302
|
@partial(jax.jit, static_argnums=0, static_argnames="n_points")
|
|
338
303
|
def energy_flux(self, params, e_low, e_high, n_points=2):
|
|
@@ -357,12 +322,12 @@ class SpectralModel(nnx.Module, Composable):
|
|
|
357
322
|
instead.
|
|
358
323
|
"""
|
|
359
324
|
|
|
360
|
-
graphdef,
|
|
361
|
-
|
|
325
|
+
graphdef, parameters, tables = nnx.split(self, nnx.Param, ...)
|
|
326
|
+
parameters = set_parameters(params, parameters)
|
|
362
327
|
|
|
363
|
-
return nnx.
|
|
328
|
+
return nnx.merge(graphdef, parameters, tables).turbo_flux(
|
|
364
329
|
e_low, e_high, n_points=n_points, energy_flux=True
|
|
365
|
-
)
|
|
330
|
+
)
|
|
366
331
|
|
|
367
332
|
|
|
368
333
|
class ModelComponent(nnx.Module, Composable, ABC):
|
|
@@ -427,6 +392,13 @@ class AdditiveComponent(ModelComponent):
|
|
|
427
392
|
class MultiplicativeComponent(ModelComponent):
|
|
428
393
|
type = "multiplicative"
|
|
429
394
|
|
|
395
|
+
def _factor(self, e_low, e_high, n_points=2):
|
|
396
|
+
energy = jnp.linspace(e_low, e_high, n_points, axis=-1)
|
|
397
|
+
factor = self.factor(energy)
|
|
398
|
+
|
|
399
|
+
return jsp.integrate.trapezoid(factor * energy, jnp.log(energy), axis=-1) / (e_high - e_low)
|
|
400
|
+
# return jnp.mean(factor, axis = -1)
|
|
401
|
+
|
|
430
402
|
def factor(self, energy):
|
|
431
403
|
"""
|
|
432
404
|
Absorption factor applied for a given energy
|
jaxspec/model/additive.py
CHANGED
|
@@ -166,19 +166,28 @@ class Gauss(AdditiveComponent):
|
|
|
166
166
|
self.sigma = nnx.Param(1e-2)
|
|
167
167
|
self.norm = nnx.Param(1.0)
|
|
168
168
|
|
|
169
|
+
def continuum(self, energy):
    """
    Photon spectrum of a Gaussian line truncated at zero energy.

    The Gaussian density is divided by its probability mass above zero,
    Phi(El / sigma) = (1 - erf(-El / (sigma * sqrt(2)))) / 2, so that the spectrum integrates to
    `norm` over positive energies. Without the factor 1/2 the line would carry only half of `norm`.
    """
    mass_above_zero = 0.5 * (1 - jsp.special.erf(-self.El / (self.sigma * jnp.sqrt(2))))
    return (
        self.norm
        * jsp.stats.norm.pdf(energy, loc=jnp.asarray(self.El), scale=jnp.asarray(self.sigma))
        / mass_above_zero
    )
|
|
175
|
+
|
|
176
|
+
"""
|
|
169
177
|
def integrated_continuum(self, e_low, e_high):
|
|
170
178
|
return self.norm * (
|
|
171
179
|
jsp.stats.norm.cdf(
|
|
172
180
|
e_high,
|
|
173
|
-
loc=jnp.asarray(self.El
|
|
174
|
-
scale=jnp.asarray(self.sigma
|
|
181
|
+
loc=jnp.asarray(self.El),
|
|
182
|
+
scale=jnp.asarray(self.sigma),
|
|
175
183
|
)
|
|
176
184
|
- jsp.stats.norm.cdf(
|
|
177
185
|
e_low,
|
|
178
|
-
loc=jnp.asarray(self.El
|
|
179
|
-
scale=jnp.asarray(self.sigma
|
|
180
|
-
)
|
|
186
|
+
loc=jnp.asarray(self.El),
|
|
187
|
+
scale=jnp.asarray(self.sigma),
|
|
188
|
+
) #/ (1 - jsp.special.erf(- self.El / (self.sigma * jnp.sqrt(2))))
|
|
181
189
|
)
|
|
190
|
+
"""
|
|
182
191
|
|
|
183
192
|
|
|
184
193
|
class Cutoffpl(AdditiveComponent):
|
jaxspec/model/background.py
CHANGED
|
@@ -5,15 +5,15 @@ import jax.numpy as jnp
|
|
|
5
5
|
import numpyro
|
|
6
6
|
import numpyro.distributions as dist
|
|
7
7
|
|
|
8
|
+
from flax import nnx
|
|
8
9
|
from jax.scipy.integrate import trapezoid
|
|
9
10
|
from numpyro.distributions import Poisson
|
|
10
11
|
from tinygp import GaussianProcess, kernels
|
|
11
12
|
|
|
12
|
-
from .._fit._build_model import build_prior, forward_model
|
|
13
13
|
from .abc import SpectralModel
|
|
14
14
|
|
|
15
15
|
|
|
16
|
-
class BackgroundModel(ABC):
|
|
16
|
+
class BackgroundModel(ABC, nnx.Module):
|
|
17
17
|
"""
|
|
18
18
|
Handles the background modelling in our spectra. This is handled in a separate class for now
|
|
19
19
|
since backgrounds can be phenomenological models fitted directly on the folded spectrum. This is not the case for
|
|
@@ -42,7 +42,7 @@ class SubtractedBackground(BackgroundModel):
|
|
|
42
42
|
|
|
43
43
|
def numpyro_model(self, observation, name: str = "", observed=True):
|
|
44
44
|
_, observed_counts = observation.out_energies, observation.folded_background.data
|
|
45
|
-
numpyro.deterministic(f"
|
|
45
|
+
numpyro.deterministic(f"bkg/~/{name}", observed_counts)
|
|
46
46
|
|
|
47
47
|
return observed_counts
|
|
48
48
|
|
|
@@ -61,11 +61,11 @@ class BackgroundWithError(BackgroundModel):
|
|
|
61
61
|
_, observed_counts = obs.out_energies, obs.folded_background.data
|
|
62
62
|
alpha = observed_counts + 1
|
|
63
63
|
beta = 1
|
|
64
|
-
countrate = numpyro.sample(f"
|
|
64
|
+
countrate = numpyro.sample(f"bkg/~/_{name}_countrate", dist.Gamma(alpha, rate=beta))
|
|
65
65
|
|
|
66
|
-
with numpyro.plate(f"
|
|
66
|
+
with numpyro.plate(f"bkg/~/{name}_plate", len(observed_counts)):
|
|
67
67
|
numpyro.sample(
|
|
68
|
-
f"
|
|
68
|
+
f"bkg/~/{name}", dist.Poisson(countrate), obs=observed_counts if observed else None
|
|
69
69
|
)
|
|
70
70
|
|
|
71
71
|
return countrate
|
|
@@ -111,27 +111,28 @@ class GaussianProcessBackground(BackgroundModel):
|
|
|
111
111
|
|
|
112
112
|
# The parameters of the GP model
|
|
113
113
|
mean = numpyro.sample(
|
|
114
|
-
f"
|
|
114
|
+
f"bkg/~/_{name}_mean", dist.Normal(jnp.log(jnp.mean(observed_counts)), 2.0)
|
|
115
115
|
)
|
|
116
|
-
sigma = numpyro.sample(f"
|
|
117
|
-
rho = numpyro.sample(f"
|
|
116
|
+
sigma = numpyro.sample(f"bkg/~/_{name}_sigma", dist.HalfNormal(3.0))
|
|
117
|
+
rho = numpyro.sample(f"bkg/~/_{name}_rho", dist.HalfNormal(10.0))
|
|
118
118
|
|
|
119
119
|
# Set up the kernel and GP objects
|
|
120
120
|
kernel = sigma**2 * self.kernel(rho)
|
|
121
121
|
nodes = jnp.linspace(0, 1, self.n_nodes)
|
|
122
122
|
gp = GaussianProcess(kernel, nodes, diag=1e-5 * jnp.ones_like(nodes), mean=mean)
|
|
123
123
|
|
|
124
|
-
log_rate = numpyro.sample(f"
|
|
124
|
+
log_rate = numpyro.sample(f"bkg/~/_{name}_log_rate_nodes", gp.numpyro_dist())
|
|
125
125
|
|
|
126
126
|
interp_count_rate = jnp.exp(
|
|
127
127
|
jnp.interp(energy, nodes * (self.e_max - self.e_min) + self.e_min, log_rate)
|
|
128
128
|
)
|
|
129
|
+
|
|
129
130
|
count_rate = trapezoid(interp_count_rate, energy, axis=0)
|
|
130
131
|
|
|
131
132
|
# Finally, our observation model is Poisson
|
|
132
|
-
with numpyro.plate("
|
|
133
|
+
with numpyro.plate("bkg/~/plate_" + name, len(observed_counts)):
|
|
133
134
|
numpyro.sample(
|
|
134
|
-
f"
|
|
135
|
+
f"bkg/~/{name}", dist.Poisson(count_rate), obs=observed_counts if observed else None
|
|
135
136
|
)
|
|
136
137
|
|
|
137
138
|
return count_rate
|
|
@@ -144,15 +145,17 @@ class SpectralModelBackground(BackgroundModel):
|
|
|
144
145
|
self.sparse = sparse
|
|
145
146
|
|
|
146
147
|
def numpyro_model(self, observation, name: str = "", observed=True):
|
|
148
|
+
from jaxspec.fit._build_model import build_prior, forward_model
|
|
149
|
+
|
|
147
150
|
params = build_prior(self.prior, prefix=f"_bkg_{name}_")
|
|
148
151
|
bkg_model = jax.jit(
|
|
149
152
|
lambda par: forward_model(self.spectral_model, par, observation, sparse=self.sparse)
|
|
150
153
|
)
|
|
151
154
|
bkg_countrate = bkg_model(params)
|
|
152
155
|
|
|
153
|
-
with numpyro.plate("
|
|
156
|
+
with numpyro.plate("bkg/~/plate_" + name, len(observation.folded_background)):
|
|
154
157
|
numpyro.sample(
|
|
155
|
-
"
|
|
158
|
+
"bkg/~/" + name,
|
|
156
159
|
Poisson(bkg_countrate),
|
|
157
160
|
obs=observation.folded_background.data if observed else None,
|
|
158
161
|
)
|